aboutsummaryrefslogtreecommitdiff
path: root/src/common/scram-common.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/scram-common.c')
-rw-r--r--src/common/scram-common.c84
1 files changed, 48 insertions, 36 deletions
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index 12686259299..bffbbb43172 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -33,6 +33,7 @@
*/
int
scram_SaltedPassword(const char *password,
+ pg_cryptohash_type hash_type, int key_length,
const char *salt, int saltlen, int iterations,
uint8 *result, const char **errstr)
{
@@ -40,9 +41,9 @@ scram_SaltedPassword(const char *password,
uint32 one = pg_hton32(1);
int i,
j;
- uint8 Ui[SCRAM_KEY_LEN];
- uint8 Ui_prev[SCRAM_KEY_LEN];
- pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256);
+ uint8 Ui[SCRAM_MAX_KEY_LEN];
+ uint8 Ui_prev[SCRAM_MAX_KEY_LEN];
+ pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
if (hmac_ctx == NULL)
{
@@ -60,30 +61,30 @@ scram_SaltedPassword(const char *password,
if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
- pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
+ pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
{
*errstr = pg_hmac_error(hmac_ctx);
pg_hmac_free(hmac_ctx);
return -1;
}
- memcpy(result, Ui_prev, SCRAM_KEY_LEN);
+ memcpy(result, Ui_prev, key_length);
/* Subsequent iterations */
for (i = 2; i <= iterations; i++)
{
if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
- pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
- pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
+ pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
+ pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
{
*errstr = pg_hmac_error(hmac_ctx);
pg_hmac_free(hmac_ctx);
return -1;
}
- for (j = 0; j < SCRAM_KEY_LEN; j++)
+ for (j = 0; j < key_length; j++)
result[j] ^= Ui[j];
- memcpy(Ui_prev, Ui, SCRAM_KEY_LEN);
+ memcpy(Ui_prev, Ui, key_length);
}
pg_hmac_free(hmac_ctx);
@@ -92,16 +93,17 @@ scram_SaltedPassword(const char *password,
/*
- * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
+ * Calculate hash for a NULL-terminated string. (The NULL terminator is
* not included in the hash). Returns 0 on success, -1 on failure with *errstr
* pointing to a message about the error details.
*/
int
-scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
+scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
+ uint8 *result, const char **errstr)
{
pg_cryptohash_ctx *ctx;
- ctx = pg_cryptohash_create(PG_SHA256);
+ ctx = pg_cryptohash_create(hash_type);
if (ctx == NULL)
{
*errstr = pg_cryptohash_error(NULL); /* returns OOM */
@@ -109,8 +111,8 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
}
if (pg_cryptohash_init(ctx) < 0 ||
- pg_cryptohash_update(ctx, input, len) < 0 ||
- pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
+ pg_cryptohash_update(ctx, input, key_length) < 0 ||
+ pg_cryptohash_final(ctx, result, key_length) < 0)
{
*errstr = pg_cryptohash_error(ctx);
pg_cryptohash_free(ctx);
@@ -126,10 +128,11 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
* pointing to a message about the error details.
*/
int
-scram_ClientKey(const uint8 *salted_password, uint8 *result,
- const char **errstr)
+scram_ClientKey(const uint8 *salted_password,
+ pg_cryptohash_type hash_type, int key_length,
+ uint8 *result, const char **errstr)
{
- pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+ pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
if (ctx == NULL)
{
@@ -137,9 +140,9 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
return -1;
}
- if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+ if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
- pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+ pg_hmac_final(ctx, result, key_length) < 0)
{
*errstr = pg_hmac_error(ctx);
pg_hmac_free(ctx);
@@ -155,10 +158,11 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
* pointing to a message about the error details.
*/
int
-scram_ServerKey(const uint8 *salted_password, uint8 *result,
- const char **errstr)
+scram_ServerKey(const uint8 *salted_password,
+ pg_cryptohash_type hash_type, int key_length,
+ uint8 *result, const char **errstr)
{
- pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+ pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
if (ctx == NULL)
{
@@ -166,9 +170,9 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
return -1;
}
- if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+ if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
- pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+ pg_hmac_final(ctx, result, key_length) < 0)
{
*errstr = pg_hmac_error(ctx);
pg_hmac_free(ctx);
@@ -192,12 +196,13 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
* error details.
*/
char *
-scram_build_secret(const char *salt, int saltlen, int iterations,
+scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+ const char *salt, int saltlen, int iterations,
const char *password, const char **errstr)
{
- uint8 salted_password[SCRAM_KEY_LEN];
- uint8 stored_key[SCRAM_KEY_LEN];
- uint8 server_key[SCRAM_KEY_LEN];
+ uint8 salted_password[SCRAM_MAX_KEY_LEN];
+ uint8 stored_key[SCRAM_MAX_KEY_LEN];
+ uint8 server_key[SCRAM_MAX_KEY_LEN];
char *result;
char *p;
int maxlen;
@@ -206,15 +211,22 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
int encoded_server_len;
int encoded_result;
+ /* Only this hash method is supported currently */
+ Assert(hash_type == PG_SHA256);
+
if (iterations <= 0)
iterations = SCRAM_DEFAULT_ITERATIONS;
/* Calculate StoredKey and ServerKey */
- if (scram_SaltedPassword(password, salt, saltlen, iterations,
+ if (scram_SaltedPassword(password, hash_type, key_length,
+ salt, saltlen, iterations,
salted_password, errstr) < 0 ||
- scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
- scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
- scram_ServerKey(salted_password, server_key, errstr) < 0)
+ scram_ClientKey(salted_password, hash_type, key_length,
+ stored_key, errstr) < 0 ||
+ scram_H(stored_key, hash_type, key_length,
+ stored_key, errstr) < 0 ||
+ scram_ServerKey(salted_password, hash_type, key_length,
+ server_key, errstr) < 0)
{
/* errstr is filled already here */
#ifdef FRONTEND
@@ -231,8 +243,8 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
*----------
*/
encoded_salt_len = pg_b64_enc_len(saltlen);
- encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
- encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+ encoded_stored_len = pg_b64_enc_len(key_length);
+ encoded_server_len = pg_b64_enc_len(key_length);
maxlen = strlen("SCRAM-SHA-256") + 1
+ 10 + 1 /* iteration count */
@@ -269,7 +281,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
*(p++) = '$';
/* stored key */
- encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+ encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
encoded_stored_len);
if (encoded_result < 0)
{
@@ -286,7 +298,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
*(p++) = ':';
/* server key */
- encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+ encoded_result = pg_b64_encode((char *) server_key, key_length, p,
encoded_server_len);
if (encoded_result < 0)
{