diff options
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r-- | src/backend/libpq/auth-scram.c | 74 |
1 files changed, 55 insertions, 19 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index 6b60abe1ddc..aa918839fb9 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -510,9 +510,11 @@ scram_verify_plain_password(const char *username, const char *password, return false; } - salt = palloc(pg_b64_dec_len(strlen(encoded_salt))); - saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt); - if (saltlen == -1) + saltlen = pg_b64_dec_len(strlen(encoded_salt)); + salt = palloc(saltlen); + saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt, + saltlen); + if (saltlen < 0) { ereport(LOG, (errmsg("invalid SCRAM verifier for user \"%s\"", username))); @@ -596,9 +598,10 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt, * Verify that the salt is in Base64-encoded format, by decoding it, * although we return the encoded version to the caller. */ - decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str))); + decoded_len = pg_b64_dec_len(strlen(salt_str)); + decoded_salt_buf = palloc(decoded_len); decoded_len = pg_b64_decode(salt_str, strlen(salt_str), - decoded_salt_buf); + decoded_salt_buf, decoded_len); if (decoded_len < 0) goto invalid_verifier; *salt = pstrdup(salt_str); @@ -606,16 +609,18 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt, /* * Decode StoredKey and ServerKey. */ - decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str))); + decoded_len = pg_b64_dec_len(strlen(storedkey_str)); + decoded_stored_buf = palloc(decoded_len); decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), - decoded_stored_buf); + decoded_stored_buf, decoded_len); if (decoded_len != SCRAM_KEY_LEN) goto invalid_verifier; memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); - decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str))); + decoded_len = pg_b64_dec_len(strlen(serverkey_str)); + decoded_server_buf = palloc(decoded_len); decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), - decoded_server_buf); + decoded_server_buf, decoded_len); if (decoded_len != SCRAM_KEY_LEN) goto invalid_verifier; memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); @@ -649,8 +654,20 @@ mock_scram_verifier(const char *username, int *iterations, char **salt, /* Generate deterministic salt */ raw_salt = scram_mock_salt(username); - encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1); - encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt); + encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN); + /* don't forget the zero-terminator */ + encoded_salt = (char *) palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt, + encoded_len); + + /* + * Note that we cannot reveal any information to an attacker here so the + * error message needs to remain generic. This should never fail anyway + * as the salt generated for mock authentication uses the cluster's nonce + * value. + */ + if (encoded_len < 0) + elog(ERROR, "could not encode salt"); encoded_salt[encoded_len] = '\0'; *salt = encoded_salt; @@ -1144,8 +1161,15 @@ build_server_first_message(scram_state *state) (errcode(ERRCODE_INTERNAL_ERROR), errmsg("could not generate random nonce"))); - state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1); - encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce); + encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN); + /* don't forget the zero-terminator */ + state->server_nonce = palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, + state->server_nonce, encoded_len); + if (encoded_len < 0) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not encode random nonce"))); state->server_nonce[encoded_len] = '\0'; state->server_first_message = @@ -1170,6 +1194,7 @@ read_client_final_message(scram_state *state, const char *input) *proof; char *p; char *client_proof; + int client_proof_len; begin = p = pstrdup(input); @@ -1234,9 +1259,13 @@ read_client_final_message(scram_state *state, const char *input) snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,"); memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); - b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1); + b64_message_len = pg_b64_enc_len(cbind_input_len); + /* don't forget the zero-terminator */ + b64_message = palloc(b64_message_len + 1); b64_message_len = pg_b64_encode(cbind_input, cbind_input_len, - b64_message); + b64_message, b64_message_len); + if (b64_message_len < 0) + elog(ERROR, "could not encode channel binding data"); b64_message[b64_message_len] = '\0'; /* @@ -1276,8 +1305,10 @@ read_client_final_message(scram_state *state, const char *input) value = read_any_attr(&p, &attr); } while (attr != 'p'); - client_proof = palloc(pg_b64_dec_len(strlen(value))); - if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN) + client_proof_len = pg_b64_dec_len(strlen(value)); + client_proof = palloc(client_proof_len); + if (pg_b64_decode(value, strlen(value), client_proof, + client_proof_len) != SCRAM_KEY_LEN) ereport(ERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("malformed SCRAM message"), @@ -1322,9 +1353,14 @@ build_server_final_message(scram_state *state) strlen(state->client_final_message_without_proof)); scram_HMAC_final(ServerSignature, &ctx); - server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); + siglen = pg_b64_enc_len(SCRAM_KEY_LEN); + /* don't forget the zero-terminator */ + server_signature_base64 = palloc(siglen + 1); siglen = pg_b64_encode((const char *) ServerSignature, - SCRAM_KEY_LEN, server_signature_base64); + SCRAM_KEY_LEN, server_signature_base64, + siglen); + if (siglen < 0) + elog(ERROR, "could not encode server signature"); server_signature_base64[siglen] = '\0'; /*------ |