aboutsummaryrefslogtreecommitdiff
path: root/src/backend/libpq/auth.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/libpq/auth.c')
-rw-r--r--src/backend/libpq/auth.c54
1 files changed, 43 insertions, 11 deletions
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 6c915a72890..2dd3328d71e 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -860,6 +860,8 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
static int
CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
{
+ char *sasl_mechs;
+ char *p;
int mtype;
StringInfoData buf;
void *scram_opaq;
@@ -869,6 +871,8 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
int inputlen;
int result;
bool initial;
+ char *tls_finished = NULL;
+ size_t tls_finished_len = 0;
/*
* SASL auth is not supported for protocol versions before 3, because it
@@ -885,12 +889,39 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
/*
* Send the SASL authentication request to user. It includes the list of
- * authentication mechanisms (which is trivial, because we only support
- * SCRAM-SHA-256 at the moment). The extra "\0" is for an empty string to
- * terminate the list.
+ * authentication mechanisms that are supported. The order of mechanisms
+ * is advertised in decreasing order of importance. So the
+ * channel-binding variants go first, if they are supported. Channel
+ * binding is only supported in SSL builds.
*/
- sendAuthRequest(port, AUTH_REQ_SASL, SCRAM_SHA256_NAME "\0",
- strlen(SCRAM_SHA256_NAME) + 2);
+ sasl_mechs = palloc(strlen(SCRAM_SHA256_PLUS_NAME) +
+ strlen(SCRAM_SHA256_NAME) + 3);
+ p = sasl_mechs;
+
+ if (port->ssl_in_use)
+ {
+ strcpy(p, SCRAM_SHA256_PLUS_NAME);
+ p += strlen(SCRAM_SHA256_PLUS_NAME) + 1;
+ }
+
+ strcpy(p, SCRAM_SHA256_NAME);
+ p += strlen(SCRAM_SHA256_NAME) + 1;
+
+ /* Put another '\0' to mark that list is finished. */
+ p[0] = '\0';
+
+ sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1);
+ pfree(sasl_mechs);
+
+#ifdef USE_SSL
+ /*
+ * Get data for channel binding.
+ */
+ if (port->ssl_in_use)
+ {
+ tls_finished = be_tls_get_peer_finished(port, &tls_finished_len);
+ }
+#endif
/*
* Initialize the status tracker for message exchanges.
@@ -903,7 +934,11 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
* This is because we don't want to reveal to an attacker what usernames
* are valid, nor which users have a valid password.
*/
- scram_opaq = pg_be_scram_init(port->user_name, shadow_pass);
+ scram_opaq = pg_be_scram_init(port->user_name,
+ shadow_pass,
+ port->ssl_in_use,
+ tls_finished,
+ tls_finished_len);
/*
* Loop through SASL message exchange. This exchange can consist of
@@ -951,12 +986,9 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
{
const char *selected_mech;
- /*
- * We only support SCRAM-SHA-256 at the moment, so anything else
- * is an error.
- */
selected_mech = pq_getmsgrawstring(&buf);
- if (strcmp(selected_mech, SCRAM_SHA256_NAME) != 0)
+ if (strcmp(selected_mech, SCRAM_SHA256_NAME) != 0 &&
+ strcmp(selected_mech, SCRAM_SHA256_PLUS_NAME) != 0)
{
ereport(ERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),