aboutsummaryrefslogtreecommitdiff
path: root/src/backend/libpq/auth-scram.c
diff options
context:
space:
mode:
authorMichael Paquier <michael@paquier.xyz>2019-07-04 16:08:09 +0900
committerMichael Paquier <michael@paquier.xyz>2019-07-04 16:08:09 +0900
commitcfc40d384ae51ea2886d599d2008ae57b529e6ea (patch)
tree725bf1bb55c99ead091f16ff9ccfc542ef7a2855 /src/backend/libpq/auth-scram.c
parentd5ab9a891cb590aad4278026b2edda685f2524a2 (diff)
downloadpostgresql-cfc40d384ae51ea2886d599d2008ae57b529e6ea.tar.gz
postgresql-cfc40d384ae51ea2886d599d2008ae57b529e6ea.zip
Introduce safer encoding and decoding routines for base64.c
This is a follow-up refactoring after 09ec55b and b674211, which has proved that the encoding and decoding routines used by SCRAM have a poor interface when it comes to check after buffer overflows. This adds an extra argument in the shape of the length of the result buffer for each routine, which is used for overflow checks when encoding or decoding an input string. The original idea comes from Tom Lane. As a result of that, the encoding routine can now fail, so all its callers are adjusted to generate proper error messages in case of problems. On failure, the result buffer gets zeroed. Author: Michael Paquier Reviewed-by: Daniel Gustafsson Discussion: https://postgr.es/m/20190623132535.GB1628@paquier.xyz
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r--src/backend/libpq/auth-scram.c74
1 files changed, 55 insertions, 19 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index 6b60abe1ddc..aa918839fb9 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -510,9 +510,11 @@ scram_verify_plain_password(const char *username, const char *password,
return false;
}
- salt = palloc(pg_b64_dec_len(strlen(encoded_salt)));
- saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt);
- if (saltlen == -1)
+ saltlen = pg_b64_dec_len(strlen(encoded_salt));
+ salt = palloc(saltlen);
+ saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
+ saltlen);
+ if (saltlen < 0)
{
ereport(LOG,
(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
@@ -596,9 +598,10 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt,
* Verify that the salt is in Base64-encoded format, by decoding it,
* although we return the encoded version to the caller.
*/
- decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str)));
+ decoded_len = pg_b64_dec_len(strlen(salt_str));
+ decoded_salt_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
- decoded_salt_buf);
+ decoded_salt_buf, decoded_len);
if (decoded_len < 0)
goto invalid_verifier;
*salt = pstrdup(salt_str);
@@ -606,16 +609,18 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt,
/*
* Decode StoredKey and ServerKey.
*/
- decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str)));
+ decoded_len = pg_b64_dec_len(strlen(storedkey_str));
+ decoded_stored_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
- decoded_stored_buf);
+ decoded_stored_buf, decoded_len);
if (decoded_len != SCRAM_KEY_LEN)
goto invalid_verifier;
memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
- decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str)));
+ decoded_len = pg_b64_dec_len(strlen(serverkey_str));
+ decoded_server_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
- decoded_server_buf);
+ decoded_server_buf, decoded_len);
if (decoded_len != SCRAM_KEY_LEN)
goto invalid_verifier;
memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
@@ -649,8 +654,20 @@ mock_scram_verifier(const char *username, int *iterations, char **salt,
/* Generate deterministic salt */
raw_salt = scram_mock_salt(username);
- encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
- encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
+ encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
+ /* don't forget the zero-terminator */
+ encoded_salt = (char *) palloc(encoded_len + 1);
+ encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
+ encoded_len);
+
+ /*
+ * Note that we cannot reveal any information to an attacker here so the
+ * error message needs to remain generic. This should never fail anyway
+ * as the salt generated for mock authentication uses the cluster's nonce
+ * value.
+ */
+ if (encoded_len < 0)
+ elog(ERROR, "could not encode salt");
encoded_salt[encoded_len] = '\0';
*salt = encoded_salt;
@@ -1144,8 +1161,15 @@ build_server_first_message(scram_state *state)
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("could not generate random nonce")));
- state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
- encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce);
+ encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+ /* don't forget the zero-terminator */
+ state->server_nonce = palloc(encoded_len + 1);
+ encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+ state->server_nonce, encoded_len);
+ if (encoded_len < 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("could not encode random nonce")));
state->server_nonce[encoded_len] = '\0';
state->server_first_message =
@@ -1170,6 +1194,7 @@ read_client_final_message(scram_state *state, const char *input)
*proof;
char *p;
char *client_proof;
+ int client_proof_len;
begin = p = pstrdup(input);
@@ -1234,9 +1259,13 @@ read_client_final_message(scram_state *state, const char *input)
snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
- b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1);
+ b64_message_len = pg_b64_enc_len(cbind_input_len);
+ /* don't forget the zero-terminator */
+ b64_message = palloc(b64_message_len + 1);
b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
- b64_message);
+ b64_message, b64_message_len);
+ if (b64_message_len < 0)
+ elog(ERROR, "could not encode channel binding data");
b64_message[b64_message_len] = '\0';
/*
@@ -1276,8 +1305,10 @@ read_client_final_message(scram_state *state, const char *input)
value = read_any_attr(&p, &attr);
} while (attr != 'p');
- client_proof = palloc(pg_b64_dec_len(strlen(value)));
- if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN)
+ client_proof_len = pg_b64_dec_len(strlen(value));
+ client_proof = palloc(client_proof_len);
+ if (pg_b64_decode(value, strlen(value), client_proof,
+ client_proof_len) != SCRAM_KEY_LEN)
ereport(ERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("malformed SCRAM message"),
@@ -1322,9 +1353,14 @@ build_server_final_message(scram_state *state)
strlen(state->client_final_message_without_proof));
scram_HMAC_final(ServerSignature, &ctx);
- server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
+ siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+ /* don't forget the zero-terminator */
+ server_signature_base64 = palloc(siglen + 1);
siglen = pg_b64_encode((const char *) ServerSignature,
- SCRAM_KEY_LEN, server_signature_base64);
+ SCRAM_KEY_LEN, server_signature_base64,
+ siglen);
+ if (siglen < 0)
+ elog(ERROR, "could not encode server signature");
server_signature_base64[siglen] = '\0';
/*------