Refactor SASL exchange to return tri-state status

The SASL exchange callback returned state in to output variables:
done and success.  This refactors that logic by introducing a new
return variable of type SASLStatus which makes the code easier to
read and understand, and prepares for future SASL exchanges which
operate asynchronously.

This was extracted from a larger patchset to introduce OAuthBearer
authentication and authorization.

Author: Jacob Champion <jacob.champion@enterprisedb.com>
Discussion: https://postgr.es/m/d1b467a78e0e36ed85a09adf979d04cf124a9d4b.camel@vmware.com
This commit is contained in:
Daniel Gustafsson 2024-03-21 14:45:46 +01:00
parent 1db689715d
commit 24178e235e
4 changed files with 70 additions and 66 deletions

View File

@ -21,6 +21,17 @@
#include "libpq-fe.h" #include "libpq-fe.h"
/*
* Possible states for the SASL exchange, see the comment on exchange for an
* explanation of these.
*/
typedef enum
{
SASL_COMPLETE = 0,
SASL_FAILED,
SASL_CONTINUE,
} SASLStatus;
/* /*
* Frontend SASL mechanism callbacks. * Frontend SASL mechanism callbacks.
* *
@ -59,7 +70,8 @@ typedef struct pg_fe_sasl_mech
* Produces a client response to a server challenge. As a special case * Produces a client response to a server challenge. As a special case
* for client-first SASL mechanisms, exchange() is called with a NULL * for client-first SASL mechanisms, exchange() is called with a NULL
* server response once at the start of the authentication exchange to * server response once at the start of the authentication exchange to
* generate an initial response. * generate an initial response. Returns a SASLStatus indicating the
* state and status of the exchange.
* *
* Input parameters: * Input parameters:
* *
@ -79,22 +91,23 @@ typedef struct pg_fe_sasl_mech
* *
* output: A malloc'd buffer containing the client's response to * output: A malloc'd buffer containing the client's response to
* the server (can be empty), or NULL if the exchange should * the server (can be empty), or NULL if the exchange should
* be aborted. (*success should be set to false in the * be aborted. (The callback should return SASL_FAILED in the
* latter case.) * latter case.)
* *
* outputlen: The length (0 or higher) of the client response buffer, * outputlen: The length (0 or higher) of the client response buffer,
* ignored if output is NULL. * ignored if output is NULL.
* *
* done: Set to true if the SASL exchange should not continue, * Return value:
* because the exchange is either complete or failed
* *
* success: Set to true if the SASL exchange completed successfully. * SASL_CONTINUE: The output buffer is filled with a client response.
* Ignored if *done is false. * Additional server challenge is expected
* SASL_COMPLETE: The SASL exchange has completed successfully.
* SASL_FAILED: The exchange has failed and the connection should be
* dropped.
*-------- *--------
*/ */
void (*exchange) (void *state, char *input, int inputlen, SASLStatus (*exchange) (void *state, char *input, int inputlen,
char **output, int *outputlen, char **output, int *outputlen);
bool *done, bool *success);
/*-------- /*--------
* channel_bound() * channel_bound()

View File

@ -24,9 +24,8 @@
/* The exported SCRAM callback mechanism. */ /* The exported SCRAM callback mechanism. */
static void *scram_init(PGconn *conn, const char *password, static void *scram_init(PGconn *conn, const char *password,
const char *sasl_mechanism); const char *sasl_mechanism);
static void scram_exchange(void *opaq, char *input, int inputlen, static SASLStatus scram_exchange(void *opaq, char *input, int inputlen,
char **output, int *outputlen, char **output, int *outputlen);
bool *done, bool *success);
static bool scram_channel_bound(void *opaq); static bool scram_channel_bound(void *opaq);
static void scram_free(void *opaq); static void scram_free(void *opaq);
@ -202,17 +201,14 @@ scram_free(void *opaq)
/* /*
* Exchange a SCRAM message with backend. * Exchange a SCRAM message with backend.
*/ */
static void static SASLStatus
scram_exchange(void *opaq, char *input, int inputlen, scram_exchange(void *opaq, char *input, int inputlen,
char **output, int *outputlen, char **output, int *outputlen)
bool *done, bool *success)
{ {
fe_scram_state *state = (fe_scram_state *) opaq; fe_scram_state *state = (fe_scram_state *) opaq;
PGconn *conn = state->conn; PGconn *conn = state->conn;
const char *errstr = NULL; const char *errstr = NULL;
*done = false;
*success = false;
*output = NULL; *output = NULL;
*outputlen = 0; *outputlen = 0;
@ -225,12 +221,12 @@ scram_exchange(void *opaq, char *input, int inputlen,
if (inputlen == 0) if (inputlen == 0)
{ {
libpq_append_conn_error(conn, "malformed SCRAM message (empty message)"); libpq_append_conn_error(conn, "malformed SCRAM message (empty message)");
goto error; return SASL_FAILED;
} }
if (inputlen != strlen(input)) if (inputlen != strlen(input))
{ {
libpq_append_conn_error(conn, "malformed SCRAM message (length mismatch)"); libpq_append_conn_error(conn, "malformed SCRAM message (length mismatch)");
goto error; return SASL_FAILED;
} }
} }
@ -240,61 +236,59 @@ scram_exchange(void *opaq, char *input, int inputlen,
/* Begin the SCRAM handshake, by sending client nonce */ /* Begin the SCRAM handshake, by sending client nonce */
*output = build_client_first_message(state); *output = build_client_first_message(state);
if (*output == NULL) if (*output == NULL)
goto error; return SASL_FAILED;
*outputlen = strlen(*output); *outputlen = strlen(*output);
*done = false;
state->state = FE_SCRAM_NONCE_SENT; state->state = FE_SCRAM_NONCE_SENT;
break; return SASL_CONTINUE;
case FE_SCRAM_NONCE_SENT: case FE_SCRAM_NONCE_SENT:
/* Receive salt and server nonce, send response. */ /* Receive salt and server nonce, send response. */
if (!read_server_first_message(state, input)) if (!read_server_first_message(state, input))
goto error; return SASL_FAILED;
*output = build_client_final_message(state); *output = build_client_final_message(state);
if (*output == NULL) if (*output == NULL)
goto error; return SASL_FAILED;
*outputlen = strlen(*output); *outputlen = strlen(*output);
*done = false;
state->state = FE_SCRAM_PROOF_SENT; state->state = FE_SCRAM_PROOF_SENT;
break; return SASL_CONTINUE;
case FE_SCRAM_PROOF_SENT: case FE_SCRAM_PROOF_SENT:
/* Receive server signature */
if (!read_server_final_message(state, input))
goto error;
/*
* Verify server signature, to make sure we're talking to the
* genuine server.
*/
if (!verify_server_signature(state, success, &errstr))
{ {
libpq_append_conn_error(conn, "could not verify server signature: %s", errstr); bool match;
goto error;
}
if (!*success) /* Receive server signature */
{ if (!read_server_final_message(state, input))
libpq_append_conn_error(conn, "incorrect server signature"); return SASL_FAILED;
/*
* Verify server signature, to make sure we're talking to the
* genuine server.
*/
if (!verify_server_signature(state, &match, &errstr))
{
libpq_append_conn_error(conn, "could not verify server signature: %s", errstr);
return SASL_FAILED;
}
if (!match)
{
libpq_append_conn_error(conn, "incorrect server signature");
}
state->state = FE_SCRAM_FINISHED;
state->conn->client_finished_auth = true;
return match ? SASL_COMPLETE : SASL_FAILED;
} }
*done = true;
state->state = FE_SCRAM_FINISHED;
state->conn->client_finished_auth = true;
break;
default: default:
/* shouldn't happen */ /* shouldn't happen */
libpq_append_conn_error(conn, "invalid SCRAM exchange state"); libpq_append_conn_error(conn, "invalid SCRAM exchange state");
goto error; break;
} }
return;
error: return SASL_FAILED;
*done = true;
*success = false;
} }
/* /*

View File

@ -423,11 +423,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
{ {
char *initialresponse = NULL; char *initialresponse = NULL;
int initialresponselen; int initialresponselen;
bool done;
bool success;
const char *selected_mechanism; const char *selected_mechanism;
PQExpBufferData mechanism_buf; PQExpBufferData mechanism_buf;
char *password; char *password;
SASLStatus status;
initPQExpBuffer(&mechanism_buf); initPQExpBuffer(&mechanism_buf);
@ -575,12 +574,11 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto oom_error; goto oom_error;
/* Get the mechanism-specific Initial Client Response, if any */ /* Get the mechanism-specific Initial Client Response, if any */
conn->sasl->exchange(conn->sasl_state, status = conn->sasl->exchange(conn->sasl_state,
NULL, -1, NULL, -1,
&initialresponse, &initialresponselen, &initialresponse, &initialresponselen);
&done, &success);
if (done && !success) if (status == SASL_FAILED)
goto error; goto error;
/* /*
@ -629,10 +627,9 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
{ {
char *output; char *output;
int outputlen; int outputlen;
bool done;
bool success;
int res; int res;
char *challenge; char *challenge;
SASLStatus status;
/* Read the SASL challenge from the AuthenticationSASLContinue message. */ /* Read the SASL challenge from the AuthenticationSASLContinue message. */
challenge = malloc(payloadlen + 1); challenge = malloc(payloadlen + 1);
@ -651,13 +648,12 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
/* For safety and convenience, ensure the buffer is NULL-terminated. */ /* For safety and convenience, ensure the buffer is NULL-terminated. */
challenge[payloadlen] = '\0'; challenge[payloadlen] = '\0';
conn->sasl->exchange(conn->sasl_state, status = conn->sasl->exchange(conn->sasl_state,
challenge, payloadlen, challenge, payloadlen,
&output, &outputlen, &output, &outputlen);
&done, &success);
free(challenge); /* don't need the input anymore */ free(challenge); /* don't need the input anymore */
if (final && !done) if (final && status == SASL_CONTINUE)
{ {
if (outputlen != 0) if (outputlen != 0)
free(output); free(output);
@ -670,7 +666,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
* If the exchange is not completed yet, we need to make sure that the * If the exchange is not completed yet, we need to make sure that the
* SASL mechanism has generated a message to send back. * SASL mechanism has generated a message to send back.
*/ */
if (output == NULL && !done) if (output == NULL && status == SASL_CONTINUE)
{ {
libpq_append_conn_error(conn, "no client response found after SASL exchange success"); libpq_append_conn_error(conn, "no client response found after SASL exchange success");
return STATUS_ERROR; return STATUS_ERROR;
@ -692,7 +688,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
return STATUS_ERROR; return STATUS_ERROR;
} }
if (done && !success) if (status == SASL_FAILED)
return STATUS_ERROR; return STATUS_ERROR;
return STATUS_OK; return STATUS_OK;

View File

@ -2442,6 +2442,7 @@ RuleLock
RuleStmt RuleStmt
RunningTransactions RunningTransactions
RunningTransactionsData RunningTransactionsData
SASLStatus
SC_HANDLE SC_HANDLE
SECURITY_ATTRIBUTES SECURITY_ATTRIBUTES
SECURITY_STATUS SECURITY_STATUS