diff options
author | Michael Paquier <michael@paquier.xyz> | 2019-07-04 16:08:09 +0900 |
---|---|---|
committer | Michael Paquier <michael@paquier.xyz> | 2019-07-04 16:08:09 +0900 |
commit | cfc40d384ae51ea2886d599d2008ae57b529e6ea (patch) | |
tree | 725bf1bb55c99ead091f16ff9ccfc542ef7a2855 /src/backend/libpq/auth-scram.c | |
parent | d5ab9a891cb590aad4278026b2edda685f2524a2 (diff) | |
download | postgresql-cfc40d384ae51ea2886d599d2008ae57b529e6ea.tar.gz postgresql-cfc40d384ae51ea2886d599d2008ae57b529e6ea.zip |
Introduce safer encoding and decoding routines for base64.c
This is a follow-up refactoring after 09ec55b and b674211, which has
proved that the encoding and decoding routines used by SCRAM have a
poor interface when it comes to check after buffer overflows. This adds
an extra argument in the shape of the length of the result buffer for
each routine, which is used for overflow checks when encoding or
decoding an input string. The original idea comes from Tom Lane.
As a result of that, the encoding routine can now fail, so all its
callers are adjusted to generate proper error messages in case of
problems.
On failure, the result buffer gets zeroed.
Author: Michael Paquier
Reviewed-by: Daniel Gustafsson
Discussion: https://postgr.es/m/20190623132535.GB1628@paquier.xyz
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r-- | src/backend/libpq/auth-scram.c | 74 |
1 files changed, 55 insertions, 19 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index 6b60abe1ddc..aa918839fb9 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -510,9 +510,11 @@ scram_verify_plain_password(const char *username, const char *password, return false; } - salt = palloc(pg_b64_dec_len(strlen(encoded_salt))); - saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt); - if (saltlen == -1) + saltlen = pg_b64_dec_len(strlen(encoded_salt)); + salt = palloc(saltlen); + saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt, + saltlen); + if (saltlen < 0) { ereport(LOG, (errmsg("invalid SCRAM verifier for user \"%s\"", username))); @@ -596,9 +598,10 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt, * Verify that the salt is in Base64-encoded format, by decoding it, * although we return the encoded version to the caller. */ - decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str))); + decoded_len = pg_b64_dec_len(strlen(salt_str)); + decoded_salt_buf = palloc(decoded_len); decoded_len = pg_b64_decode(salt_str, strlen(salt_str), - decoded_salt_buf); + decoded_salt_buf, decoded_len); if (decoded_len < 0) goto invalid_verifier; *salt = pstrdup(salt_str); @@ -606,16 +609,18 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt, /* * Decode StoredKey and ServerKey. */ - decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str))); + decoded_len = pg_b64_dec_len(strlen(storedkey_str)); + decoded_stored_buf = palloc(decoded_len); decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), - decoded_stored_buf); + decoded_stored_buf, decoded_len); if (decoded_len != SCRAM_KEY_LEN) goto invalid_verifier; memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); - decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str))); + decoded_len = pg_b64_dec_len(strlen(serverkey_str)); + decoded_server_buf = palloc(decoded_len); decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), - decoded_server_buf); + decoded_server_buf, decoded_len); if (decoded_len != SCRAM_KEY_LEN) goto invalid_verifier; memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); @@ -649,8 +654,20 @@ mock_scram_verifier(const char *username, int *iterations, char **salt, /* Generate deterministic salt */ raw_salt = scram_mock_salt(username); - encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1); - encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt); + encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN); + /* don't forget the zero-terminator */ + encoded_salt = (char *) palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt, + encoded_len); + + /* + * Note that we cannot reveal any information to an attacker here so the + * error message needs to remain generic. This should never fail anyway + * as the salt generated for mock authentication uses the cluster's nonce + * value. + */ + if (encoded_len < 0) + elog(ERROR, "could not encode salt"); encoded_salt[encoded_len] = '\0'; *salt = encoded_salt; @@ -1144,8 +1161,15 @@ build_server_first_message(scram_state *state) (errcode(ERRCODE_INTERNAL_ERROR), errmsg("could not generate random nonce"))); - state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1); - encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce); + encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN); + /* don't forget the zero-terminator */ + state->server_nonce = palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, + state->server_nonce, encoded_len); + if (encoded_len < 0) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not encode random nonce"))); state->server_nonce[encoded_len] = '\0'; state->server_first_message = @@ -1170,6 +1194,7 @@ read_client_final_message(scram_state *state, const char *input) *proof; char *p; char *client_proof; + int client_proof_len; begin = p = pstrdup(input); @@ -1234,9 +1259,13 @@ read_client_final_message(scram_state *state, const char *input) snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,"); memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); - b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1); + b64_message_len = pg_b64_enc_len(cbind_input_len); + /* don't forget the zero-terminator */ + b64_message = palloc(b64_message_len + 1); b64_message_len = pg_b64_encode(cbind_input, cbind_input_len, - b64_message); + b64_message, b64_message_len); + if (b64_message_len < 0) + elog(ERROR, "could not encode channel binding data"); b64_message[b64_message_len] = '\0'; /* @@ -1276,8 +1305,10 @@ read_client_final_message(scram_state *state, const char *input) value = read_any_attr(&p, &attr); } while (attr != 'p'); - client_proof = palloc(pg_b64_dec_len(strlen(value))); - if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN) + client_proof_len = pg_b64_dec_len(strlen(value)); + client_proof = palloc(client_proof_len); + if (pg_b64_decode(value, strlen(value), client_proof, + client_proof_len) != SCRAM_KEY_LEN) ereport(ERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("malformed SCRAM message"), @@ -1322,9 +1353,14 @@ build_server_final_message(scram_state *state) strlen(state->client_final_message_without_proof)); scram_HMAC_final(ServerSignature, &ctx); - server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); + siglen = pg_b64_enc_len(SCRAM_KEY_LEN); + /* don't forget the zero-terminator */ + server_signature_base64 = palloc(siglen + 1); siglen = pg_b64_encode((const char *) ServerSignature, - SCRAM_KEY_LEN, server_signature_base64); + SCRAM_KEY_LEN, server_signature_base64, + siglen); + if (siglen < 0) + elog(ERROR, "could not encode server signature"); server_signature_base64[siglen] = '\0'; /*------ |