fe_scram_state_enum state;
/* These are supplied by the user */
- const char *username;
+ PGconn *conn;
char *password;
- bool ssl_in_use;
- char *tls_finished_message;
- size_t tls_finished_len;
char *sasl_mechanism;
- const char *channel_binding_type;
/* We construct these */
uint8 SaltedPassword[SCRAM_KEY_LEN];
char ServerSignature[SCRAM_KEY_LEN];
} fe_scram_state;
-static bool read_server_first_message(fe_scram_state *state, char *input,
- PQExpBuffer errormessage);
-static bool read_server_final_message(fe_scram_state *state, char *input,
- PQExpBuffer errormessage);
-static char *build_client_first_message(fe_scram_state *state,
- PQExpBuffer errormessage);
-static char *build_client_final_message(fe_scram_state *state,
- PQExpBuffer errormessage);
+static bool read_server_first_message(fe_scram_state *state, char *input);
+static bool read_server_final_message(fe_scram_state *state, char *input);
+static char *build_client_first_message(fe_scram_state *state);
+static char *build_client_final_message(fe_scram_state *state);
static bool verify_server_signature(fe_scram_state *state);
static void calculate_client_proof(fe_scram_state *state,
const char *client_final_message_without_proof,
/*
* Initialize SCRAM exchange status.
- *
- * The non-const char* arguments should be passed in malloc'ed. They will be
- * freed by pg_fe_scram_free().
*/
void *
-pg_fe_scram_init(const char *username,
+pg_fe_scram_init(PGconn *conn,
const char *password,
- bool ssl_in_use,
- const char *sasl_mechanism,
- const char *channel_binding_type,
- char *tls_finished_message,
- size_t tls_finished_len)
+ const char *sasl_mechanism)
{
fe_scram_state *state;
char *prep_password;
if (!state)
return NULL;
memset(state, 0, sizeof(fe_scram_state));
+ state->conn = conn;
state->state = FE_SCRAM_INIT;
- state->username = username;
- state->ssl_in_use = ssl_in_use;
- state->tls_finished_message = tls_finished_message;
- state->tls_finished_len = tls_finished_len;
state->sasl_mechanism = strdup(sasl_mechanism);
- state->channel_binding_type = channel_binding_type;
if (!state->sasl_mechanism)
{
if (state->password)
free(state->password);
- if (state->tls_finished_message)
- free(state->tls_finished_message);
if (state->sasl_mechanism)
free(state->sasl_mechanism);
void
pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
char **output, int *outputlen,
- bool *done, bool *success, PQExpBuffer errorMessage)
+ bool *done, bool *success)
{
fe_scram_state *state = (fe_scram_state *) opaq;
+ PGconn *conn = state->conn;
*done = false;
*success = false;
{
if (inputlen == 0)
{
- printfPQExpBuffer(errorMessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (empty message)\n"));
goto error;
}
if (inputlen != strlen(input))
{
- printfPQExpBuffer(errorMessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (length mismatch)\n"));
goto error;
}
{
case FE_SCRAM_INIT:
/* Begin the SCRAM handshake, by sending client nonce */
- *output = build_client_first_message(state, errorMessage);
+ *output = build_client_first_message(state);
if (*output == NULL)
goto error;
case FE_SCRAM_NONCE_SENT:
/* Receive salt and server nonce, send response. */
- if (!read_server_first_message(state, input, errorMessage))
+ if (!read_server_first_message(state, input))
goto error;
- *output = build_client_final_message(state, errorMessage);
+ *output = build_client_final_message(state);
if (*output == NULL)
goto error;
case FE_SCRAM_PROOF_SENT:
/* Receive server signature */
- if (!read_server_final_message(state, input, errorMessage))
+ if (!read_server_final_message(state, input))
goto error;
/*
else
{
*success = false;
- printfPQExpBuffer(errorMessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("incorrect server signature\n"));
}
*done = true;
default:
/* shouldn't happen */
- printfPQExpBuffer(errorMessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("invalid SCRAM exchange state\n"));
goto error;
}
* Build the first exchange message sent by the client.
*/
static char *
-build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
+build_client_first_message(fe_scram_state *state)
{
+ PGconn *conn = state->conn;
char raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
char *result;
int channel_info_len;
*/
if (!pg_frontend_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("could not generate nonce\n"));
return NULL;
}
state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
if (state->client_nonce == NULL)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return NULL;
}
*/
if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
{
- Assert(state->ssl_in_use);
- appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type);
+ Assert(conn->ssl_in_use);
+ appendPQExpBuffer(&buf, "p=%s", conn->scram_channel_binding);
}
- else if (state->channel_binding_type == NULL ||
- strlen(state->channel_binding_type) == 0)
+ else if (conn->scram_channel_binding == NULL ||
+ strlen(conn->scram_channel_binding) == 0)
{
/*
* Client has chosen to not show to server that it supports channel
*/
appendPQExpBuffer(&buf, "n");
}
- else if (state->ssl_in_use)
+ else if (conn->ssl_in_use)
{
/*
* Client supports channel binding, but thinks the server does not.
oom_error:
termPQExpBuffer(&buf);
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return NULL;
}
* Build the final exchange message sent from the client.
*/
static char *
-build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
+build_client_final_message(fe_scram_state *state)
{
PQExpBufferData buf;
+ PGconn *conn = state->conn;
uint8 client_proof[SCRAM_KEY_LEN];
char *result;
*/
if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
{
- char *cbind_data;
- size_t cbind_data_len;
+ char *cbind_data = NULL;
+ size_t cbind_data_len = 0;
size_t cbind_header_len;
char *cbind_input;
size_t cbind_input_len;
- if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
+ if (strcmp(conn->scram_channel_binding, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
{
- cbind_data = state->tls_finished_message;
- cbind_data_len = state->tls_finished_len;
+#ifdef USE_SSL
+ cbind_data = pgtls_get_finished(state->conn, &cbind_data_len);
+ if (cbind_data == NULL)
+ goto oom_error;
+#endif
}
else
{
/* should not happen */
termPQExpBuffer(&buf);
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("invalid channel binding type\n"));
return NULL;
}
/* should not happen */
if (cbind_data == NULL || cbind_data_len == 0)
{
+ if (cbind_data != NULL)
+ free(cbind_data);
termPQExpBuffer(&buf);
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"),
- state->channel_binding_type);
+ conn->scram_channel_binding);
return NULL;
}
appendPQExpBuffer(&buf, "c=");
- cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */
+ /* p=type,, */
+ cbind_header_len = 4 + strlen(conn->scram_channel_binding);
cbind_input_len = cbind_header_len + cbind_data_len;
cbind_input = malloc(cbind_input_len);
if (!cbind_input)
+ {
+ free(cbind_data);
goto oom_error;
- snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type);
+ }
+ snprintf(cbind_input, cbind_input_len, "p=%s,,",
+ conn->scram_channel_binding);
memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
{
+ free(cbind_data);
free(cbind_input);
goto oom_error;
}
buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
buf.data[buf.len] = '\0';
+ free(cbind_data);
free(cbind_input);
}
- else if (state->channel_binding_type == NULL ||
- strlen(state->channel_binding_type) == 0)
+ else if (conn->scram_channel_binding == NULL ||
+ strlen(conn->scram_channel_binding) == 0)
appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */
- else if (state->ssl_in_use)
+ else if (conn->ssl_in_use)
appendPQExpBuffer(&buf, "c=eSws"); /* base64 of "y,," */
else
appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */
oom_error:
termPQExpBuffer(&buf);
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return NULL;
}
* Read the first exchange message coming from the server.
*/
static bool
-read_server_first_message(fe_scram_state *state, char *input,
- PQExpBuffer errormessage)
+read_server_first_message(fe_scram_state *state, char *input)
{
+ PGconn *conn = state->conn;
char *iterations_str;
char *endptr;
char *encoded_salt;
state->server_first_message = strdup(input);
if (state->server_first_message == NULL)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return false;
}
/* parse the message */
- nonce = read_attr_value(&input, 'r', errormessage);
+ nonce = read_attr_value(&input, 'r',
+ &conn->errorMessage);
if (nonce == NULL)
{
/* read_attr_value() has generated an error string */
if (strlen(nonce) < strlen(state->client_nonce) ||
memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("invalid SCRAM response (nonce mismatch)\n"));
return false;
}
state->nonce = strdup(nonce);
if (state->nonce == NULL)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return false;
}
- encoded_salt = read_attr_value(&input, 's', errormessage);
+ encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
if (encoded_salt == NULL)
{
/* read_attr_value() has generated an error string */
state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
if (state->salt == NULL)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return false;
}
strlen(encoded_salt),
state->salt);
- iterations_str = read_attr_value(&input, 'i', errormessage);
+ iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
if (iterations_str == NULL)
{
/* read_attr_value() has generated an error string */
state->iterations = strtol(iterations_str, &endptr, 10);
if (*endptr != '\0' || state->iterations < 1)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (invalid iteration count)\n"));
return false;
}
if (*input != '\0')
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (garbage at end of server-first-message)\n"));
return true;
* Read the final exchange message coming from the server.
*/
static bool
-read_server_final_message(fe_scram_state *state, char *input,
- PQExpBuffer errormessage)
+read_server_final_message(fe_scram_state *state, char *input)
{
+ PGconn *conn = state->conn;
char *encoded_server_signature;
int server_signature_len;
state->server_final_message = strdup(input);
if (!state->server_final_message)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n"));
return false;
}
/* Check for error result. */
if (*input == 'e')
{
- char *errmsg = read_attr_value(&input, 'e', errormessage);
+ char *errmsg = read_attr_value(&input, 'e',
+ &conn->errorMessage);
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("error received from server in SCRAM exchange: %s\n"),
errmsg);
return false;
}
/* Parse the message. */
- encoded_server_signature = read_attr_value(&input, 'v', errormessage);
+ encoded_server_signature = read_attr_value(&input, 'v',
+ &conn->errorMessage);
if (encoded_server_signature == NULL)
{
/* read_attr_value() has generated an error message */
}
if (*input != '\0')
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
server_signature_len = pg_b64_decode(encoded_server_signature,
state->ServerSignature);
if (server_signature_len != SCRAM_KEY_LEN)
{
- printfPQExpBuffer(errormessage,
+ printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
return false;
}