diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index 6b60abe1dd..aa918839fb 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -510,9 +510,11 @@ scram_verify_plain_password(const char *username, const char *password, return false; } - salt = palloc(pg_b64_dec_len(strlen(encoded_salt))); - saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt); - if (saltlen == -1) + saltlen = pg_b64_dec_len(strlen(encoded_salt)); + salt = palloc(saltlen); + saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt, + saltlen); + if (saltlen < 0) { ereport(LOG, (errmsg("invalid SCRAM verifier for user \"%s\"", username))); @@ -596,9 +598,10 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt, * Verify that the salt is in Base64-encoded format, by decoding it, * although we return the encoded version to the caller. */ - decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str))); + decoded_len = pg_b64_dec_len(strlen(salt_str)); + decoded_salt_buf = palloc(decoded_len); decoded_len = pg_b64_decode(salt_str, strlen(salt_str), - decoded_salt_buf); + decoded_salt_buf, decoded_len); if (decoded_len < 0) goto invalid_verifier; *salt = pstrdup(salt_str); @@ -606,16 +609,18 @@ parse_scram_verifier(const char *verifier, int *iterations, char **salt, /* * Decode StoredKey and ServerKey. */ - decoded_stored_buf = palloc(pg_b64_dec_len(strlen(storedkey_str))); + decoded_len = pg_b64_dec_len(strlen(storedkey_str)); + decoded_stored_buf = palloc(decoded_len); decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), - decoded_stored_buf); + decoded_stored_buf, decoded_len); if (decoded_len != SCRAM_KEY_LEN) goto invalid_verifier; memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); - decoded_server_buf = palloc(pg_b64_dec_len(strlen(serverkey_str))); + decoded_len = pg_b64_dec_len(strlen(serverkey_str)); + decoded_server_buf = palloc(decoded_len); decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), - decoded_server_buf); + decoded_server_buf, decoded_len); if (decoded_len != SCRAM_KEY_LEN) goto invalid_verifier; memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); @@ -649,8 +654,20 @@ mock_scram_verifier(const char *username, int *iterations, char **salt, /* Generate deterministic salt */ raw_salt = scram_mock_salt(username); - encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1); - encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt); + encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN); + /* don't forget the zero-terminator */ + encoded_salt = (char *) palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt, + encoded_len); + + /* + * Note that we cannot reveal any information to an attacker here so the + * error message needs to remain generic. This should never fail anyway + * as the salt generated for mock authentication uses the cluster's nonce + * value. + */ + if (encoded_len < 0) + elog(ERROR, "could not encode salt"); encoded_salt[encoded_len] = '\0'; *salt = encoded_salt; @@ -1144,8 +1161,15 @@ build_server_first_message(scram_state *state) (errcode(ERRCODE_INTERNAL_ERROR), errmsg("could not generate random nonce"))); - state->server_nonce = palloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1); - encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->server_nonce); + encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN); + /* don't forget the zero-terminator */ + state->server_nonce = palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, + state->server_nonce, encoded_len); + if (encoded_len < 0) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not encode random nonce"))); state->server_nonce[encoded_len] = '\0'; state->server_first_message = @@ -1170,6 +1194,7 @@ read_client_final_message(scram_state *state, const char *input) *proof; char *p; char *client_proof; + int client_proof_len; begin = p = pstrdup(input); @@ -1234,9 +1259,13 @@ read_client_final_message(scram_state *state, const char *input) snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,"); memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); - b64_message = palloc(pg_b64_enc_len(cbind_input_len) + 1); + b64_message_len = pg_b64_enc_len(cbind_input_len); + /* don't forget the zero-terminator */ + b64_message = palloc(b64_message_len + 1); b64_message_len = pg_b64_encode(cbind_input, cbind_input_len, - b64_message); + b64_message, b64_message_len); + if (b64_message_len < 0) + elog(ERROR, "could not encode channel binding data"); b64_message[b64_message_len] = '\0'; /* @@ -1276,8 +1305,10 @@ read_client_final_message(scram_state *state, const char *input) value = read_any_attr(&p, &attr); } while (attr != 'p'); - client_proof = palloc(pg_b64_dec_len(strlen(value))); - if (pg_b64_decode(value, strlen(value), client_proof) != SCRAM_KEY_LEN) + client_proof_len = pg_b64_dec_len(strlen(value)); + client_proof = palloc(client_proof_len); + if (pg_b64_decode(value, strlen(value), client_proof, + client_proof_len) != SCRAM_KEY_LEN) ereport(ERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("malformed SCRAM message"), @@ -1322,9 +1353,14 @@ build_server_final_message(scram_state *state) strlen(state->client_final_message_without_proof)); scram_HMAC_final(ServerSignature, &ctx); - server_signature_base64 = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); + siglen = pg_b64_enc_len(SCRAM_KEY_LEN); + /* don't forget the zero-terminator */ + server_signature_base64 = palloc(siglen + 1); siglen = pg_b64_encode((const char *) ServerSignature, - SCRAM_KEY_LEN, server_signature_base64); + SCRAM_KEY_LEN, server_signature_base64, + siglen); + if (siglen < 0) + elog(ERROR, "could not encode server signature"); server_signature_base64[siglen] = '\0'; /*------ diff --git a/src/common/base64.c b/src/common/base64.c index 55c8983f97..57ec06c3a9 100644 --- a/src/common/base64.c +++ b/src/common/base64.c @@ -42,10 +42,11 @@ static const int8 b64lookup[128] = { * pg_b64_encode * * Encode into base64 the given string. Returns the length of the encoded - * string. + * string, and -1 in the event of an error with the result buffer zeroed + * for safety. */ int -pg_b64_encode(const char *src, int len, char *dst) +pg_b64_encode(const char *src, int len, char *dst, int dstlen) { char *p; const char *s, @@ -65,6 +66,13 @@ pg_b64_encode(const char *src, int len, char *dst) /* write it out */ if (pos < 0) { + /* + * Leave if there is an overflow in the area allocated for the + * encoded string. + */ + if ((p - dst + 4) > dstlen) + goto error; + *p++ = _base64[(buf >> 18) & 0x3f]; *p++ = _base64[(buf >> 12) & 0x3f]; *p++ = _base64[(buf >> 6) & 0x3f]; @@ -76,23 +84,36 @@ pg_b64_encode(const char *src, int len, char *dst) } if (pos != 2) { + /* + * Leave if there is an overflow in the area allocated for the encoded + * string. + */ + if ((p - dst + 4) > dstlen) + goto error; + *p++ = _base64[(buf >> 18) & 0x3f]; *p++ = _base64[(buf >> 12) & 0x3f]; *p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '='; *p++ = '='; } + Assert((p - dst) <= dstlen); return p - dst; + +error: + memset(dst, 0, dstlen); + return -1; } /* * pg_b64_decode * * Decode the given base64 string. Returns the length of the decoded - * string on success, and -1 in the event of an error. + * string on success, and -1 in the event of an error with the result + * buffer zeroed for safety. */ int -pg_b64_decode(const char *src, int len, char *dst) +pg_b64_decode(const char *src, int len, char *dst, int dstlen) { const char *srcend = src + len, *s = src; @@ -109,7 +130,7 @@ pg_b64_decode(const char *src, int len, char *dst) /* Leave if a whitespace is found */ if (c == ' ' || c == '\t' || c == '\n' || c == '\r') - return -1; + goto error; if (c == '=') { @@ -126,7 +147,7 @@ pg_b64_decode(const char *src, int len, char *dst) * Unexpected "=" character found while decoding base64 * sequence. */ - return -1; + goto error; } } b = 0; @@ -139,7 +160,7 @@ pg_b64_decode(const char *src, int len, char *dst) if (b < 0) { /* invalid symbol found */ - return -1; + goto error; } } /* add it to buffer */ @@ -147,11 +168,28 @@ pg_b64_decode(const char *src, int len, char *dst) pos++; if (pos == 4) { + /* + * Leave if there is an overflow in the area allocated for the + * decoded string. + */ + if ((p - dst + 1) > dstlen) + goto error; *p++ = (buf >> 16) & 255; + if (end == 0 || end > 1) + { + /* overflow check */ + if ((p - dst + 1) > dstlen) + goto error; *p++ = (buf >> 8) & 255; + } if (end == 0 || end > 2) + { + /* overflow check */ + if ((p - dst + 1) > dstlen) + goto error; *p++ = buf & 255; + } buf = 0; pos = 0; } @@ -163,10 +201,15 @@ pg_b64_decode(const char *src, int len, char *dst) * base64 end sequence is invalid. Input data is missing padding, is * truncated or is otherwise corrupted. */ - return -1; + goto error; } + Assert((p - dst) <= dstlen); return p - dst; + +error: + memset(dst, 0, dstlen); + return -1; } /* diff --git a/src/common/scram-common.c b/src/common/scram-common.c index c30dfc97dc..dff9723e67 100644 --- a/src/common/scram-common.c +++ b/src/common/scram-common.c @@ -198,6 +198,10 @@ scram_build_verifier(const char *salt, int saltlen, int iterations, char *result; char *p; int maxlen; + int encoded_salt_len; + int encoded_stored_len; + int encoded_server_len; + int encoded_result; if (iterations <= 0) iterations = SCRAM_DEFAULT_ITERATIONS; @@ -215,11 +219,15 @@ scram_build_verifier(const char *salt, int saltlen, int iterations, * SCRAM-SHA-256$:$: *---------- */ + encoded_salt_len = pg_b64_enc_len(saltlen); + encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN); + encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN); + maxlen = strlen("SCRAM-SHA-256") + 1 + 10 + 1 /* iteration count */ - + pg_b64_enc_len(saltlen) + 1 /* Base64-encoded salt */ - + pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */ - + pg_b64_enc_len(SCRAM_KEY_LEN) + 1; /* Base64-encoded ServerKey */ + + encoded_salt_len + 1 /* Base64-encoded salt */ + + encoded_stored_len + 1 /* Base64-encoded StoredKey */ + + encoded_server_len + 1; /* Base64-encoded ServerKey */ #ifdef FRONTEND result = malloc(maxlen); @@ -231,11 +239,50 @@ scram_build_verifier(const char *salt, int saltlen, int iterations, p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations); - p += pg_b64_encode(salt, saltlen, p); + /* salt */ + encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len); + if (encoded_result < 0) + { +#ifdef FRONTEND + free(result); + return NULL; +#else + elog(ERROR, "could not encode salt"); +#endif + } + p += encoded_result; *(p++) = '$'; - p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p); + + /* stored key */ + encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p, + encoded_stored_len); + if (encoded_result < 0) + { +#ifdef FRONTEND + free(result); + return NULL; +#else + elog(ERROR, "could not encode stored key"); +#endif + } + + p += encoded_result; *(p++) = ':'; - p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p); + + /* server key */ + encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p, + encoded_server_len); + if (encoded_result < 0) + { +#ifdef FRONTEND + free(result); + return NULL; +#else + elog(ERROR, "could not encode server key"); +#endif + } + + p += encoded_result; *(p++) = '\0'; Assert(p - result <= maxlen); diff --git a/src/include/common/base64.h b/src/include/common/base64.h index 1bae5ec966..c30b173483 100644 --- a/src/include/common/base64.h +++ b/src/include/common/base64.h @@ -11,8 +11,8 @@ #define BASE64_H /* base 64 */ -extern int pg_b64_encode(const char *src, int len, char *dst); -extern int pg_b64_decode(const char *src, int len, char *dst); +extern int pg_b64_encode(const char *src, int len, char *dst, int dstlen); +extern int pg_b64_decode(const char *src, int len, char *dst, int dstlen); extern int pg_b64_enc_len(int srclen); extern int pg_b64_dec_len(int srclen); diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c index 04ee43441c..7a8335bf9f 100644 --- a/src/interfaces/libpq/fe-auth-scram.c +++ b/src/interfaces/libpq/fe-auth-scram.c @@ -321,14 +321,23 @@ build_client_first_message(fe_scram_state *state) return NULL; } - state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1); + encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN); + /* don't forget the zero-terminator */ + state->client_nonce = malloc(encoded_len + 1); if (state->client_nonce == NULL) { printfPQExpBuffer(&conn->errorMessage, libpq_gettext("out of memory\n")); return NULL; } - encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, state->client_nonce); + encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, + state->client_nonce, encoded_len); + if (encoded_len < 0) + { + printfPQExpBuffer(&conn->errorMessage, + libpq_gettext("could not encode nonce\n")); + return NULL; + } state->client_nonce[encoded_len] = '\0'; /* @@ -406,6 +415,7 @@ build_client_final_message(fe_scram_state *state) PGconn *conn = state->conn; uint8 client_proof[SCRAM_KEY_LEN]; char *result; + int encoded_len; initPQExpBuffer(&buf); @@ -425,6 +435,7 @@ build_client_final_message(fe_scram_state *state) size_t cbind_header_len; char *cbind_input; size_t cbind_input_len; + int encoded_cbind_len; /* Fetch hash data of server's SSL certificate */ cbind_data = @@ -451,13 +462,26 @@ build_client_final_message(fe_scram_state *state) memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len); memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); - if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len))) + encoded_cbind_len = pg_b64_enc_len(cbind_input_len); + if (!enlargePQExpBuffer(&buf, encoded_cbind_len)) { free(cbind_data); free(cbind_input); goto oom_error; } - buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len); + encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len, + buf.data + buf.len, + encoded_cbind_len); + if (encoded_cbind_len < 0) + { + free(cbind_data); + free(cbind_input); + termPQExpBuffer(&buf); + printfPQExpBuffer(&conn->errorMessage, + "could not encode cbind data for channel binding\n"); + return NULL; + } + buf.len += encoded_cbind_len; buf.data[buf.len] = '\0'; free(cbind_data); @@ -497,11 +521,21 @@ build_client_final_message(fe_scram_state *state) client_proof); appendPQExpBufferStr(&buf, ",p="); - if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(SCRAM_KEY_LEN))) + encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN); + if (!enlargePQExpBuffer(&buf, encoded_len)) goto oom_error; - buf.len += pg_b64_encode((char *) client_proof, - SCRAM_KEY_LEN, - buf.data + buf.len); + encoded_len = pg_b64_encode((char *) client_proof, + SCRAM_KEY_LEN, + buf.data + buf.len, + encoded_len); + if (encoded_len < 0) + { + termPQExpBuffer(&buf); + printfPQExpBuffer(&conn->errorMessage, + libpq_gettext("could not encode client proof\n")); + return NULL; + } + buf.len += encoded_len; buf.data[buf.len] = '\0'; result = strdup(buf.data); @@ -529,6 +563,7 @@ read_server_first_message(fe_scram_state *state, char *input) char *endptr; char *encoded_salt; char *nonce; + int decoded_salt_len; state->server_first_message = strdup(input); if (state->server_first_message == NULL) @@ -570,7 +605,8 @@ read_server_first_message(fe_scram_state *state, char *input) /* read_attr_value() has generated an error string */ return false; } - state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt))); + decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt)); + state->salt = malloc(decoded_salt_len); if (state->salt == NULL) { printfPQExpBuffer(&conn->errorMessage, @@ -579,7 +615,8 @@ read_server_first_message(fe_scram_state *state, char *input) } state->saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), - state->salt); + state->salt, + decoded_salt_len); if (state->saltlen < 0) { printfPQExpBuffer(&conn->errorMessage, @@ -663,7 +700,8 @@ read_server_final_message(fe_scram_state *state, char *input) server_signature_len = pg_b64_decode(encoded_server_signature, strlen(encoded_server_signature), - decoded_server_signature); + decoded_server_signature, + server_signature_len); if (server_signature_len != SCRAM_KEY_LEN) { free(decoded_server_signature);