diff options
Diffstat (limited to 'src/backend/libpq/auth.c')
-rw-r--r-- | src/backend/libpq/auth.c | 54 |
1 files changed, 43 insertions, 11 deletions
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c index 6c915a72890..2dd3328d71e 100644 --- a/src/backend/libpq/auth.c +++ b/src/backend/libpq/auth.c @@ -860,6 +860,8 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail) static int CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) { + char *sasl_mechs; + char *p; int mtype; StringInfoData buf; void *scram_opaq; @@ -869,6 +871,8 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) int inputlen; int result; bool initial; + char *tls_finished = NULL; + size_t tls_finished_len = 0; /* * SASL auth is not supported for protocol versions before 3, because it @@ -885,12 +889,39 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) /* * Send the SASL authentication request to user. It includes the list of - * authentication mechanisms (which is trivial, because we only support - * SCRAM-SHA-256 at the moment). The extra "\0" is for an empty string to - * terminate the list. + * authentication mechanisms that are supported. The order of mechanisms + * is advertised in decreasing order of importance. So the + * channel-binding variants go first, if they are supported. Channel + * binding is only supported in SSL builds. */ - sendAuthRequest(port, AUTH_REQ_SASL, SCRAM_SHA256_NAME "\0", - strlen(SCRAM_SHA256_NAME) + 2); + sasl_mechs = palloc(strlen(SCRAM_SHA256_PLUS_NAME) + + strlen(SCRAM_SHA256_NAME) + 3); + p = sasl_mechs; + + if (port->ssl_in_use) + { + strcpy(p, SCRAM_SHA256_PLUS_NAME); + p += strlen(SCRAM_SHA256_PLUS_NAME) + 1; + } + + strcpy(p, SCRAM_SHA256_NAME); + p += strlen(SCRAM_SHA256_NAME) + 1; + + /* Put another '\0' to mark that list is finished. */ + p[0] = '\0'; + + sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1); + pfree(sasl_mechs); + +#ifdef USE_SSL + /* + * Get data for channel binding. + */ + if (port->ssl_in_use) + { + tls_finished = be_tls_get_peer_finished(port, &tls_finished_len); + } +#endif /* * Initialize the status tracker for message exchanges. @@ -903,7 +934,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) * This is because we don't want to reveal to an attacker what usernames * are valid, nor which users have a valid password. */ - scram_opaq = pg_be_scram_init(port->user_name, shadow_pass); + scram_opaq = pg_be_scram_init(port->user_name, + shadow_pass, + port->ssl_in_use, + tls_finished, + tls_finished_len); /* * Loop through SASL message exchange. This exchange can consist of @@ -951,12 +986,9 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) { const char *selected_mech; - /* - * We only support SCRAM-SHA-256 at the moment, so anything else - * is an error. - */ selected_mech = pq_getmsgrawstring(&buf); - if (strcmp(selected_mech, SCRAM_SHA256_NAME) != 0) + if (strcmp(selected_mech, SCRAM_SHA256_NAME) != 0 && + strcmp(selected_mech, SCRAM_SHA256_PLUS_NAME) != 0) { ereport(ERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), |