Refactor sendAuthRequest.

This way sendAuthRequest doesn't need to know the details of all the
different authentication methods. This is in preparation for adding SCRAM
authentication, which will add yet another authentication request message
type, with different payload.

Reviewed-By: Michael Paquier
Discussion: <CAB7nPqQvO4sxLFeS9D+NM3wpy08ieZdAj_6e117MQHZAfxBFsg@mail.gmail.com>
This commit is contained in:
Heikki Linnakangas 2016-08-18 13:25:31 +03:00
parent 07ef035129
commit 8d3b9cce81
1 changed files with 18 additions and 34 deletions

View File

@ -36,7 +36,8 @@
* Global authentication functions * Global authentication functions
*---------------------------------------------------------------- *----------------------------------------------------------------
*/ */
static void sendAuthRequest(Port *port, AuthRequest areq); static void sendAuthRequest(Port *port, AuthRequest areq, char *extradata,
int extralen);
static void auth_failed(Port *port, int status, char *logdetail); static void auth_failed(Port *port, int status, char *logdetail);
static char *recv_password_packet(Port *port); static char *recv_password_packet(Port *port);
static int recv_and_check_password_packet(Port *port, char **logdetail); static int recv_and_check_password_packet(Port *port, char **logdetail);
@ -498,7 +499,7 @@ ClientAuthentication(Port *port)
case uaGSS: case uaGSS:
#ifdef ENABLE_GSS #ifdef ENABLE_GSS
sendAuthRequest(port, AUTH_REQ_GSS); sendAuthRequest(port, AUTH_REQ_GSS, NULL, 0);
status = pg_GSS_recvauth(port); status = pg_GSS_recvauth(port);
#else #else
Assert(false); Assert(false);
@ -507,7 +508,7 @@ ClientAuthentication(Port *port)
case uaSSPI: case uaSSPI:
#ifdef ENABLE_SSPI #ifdef ENABLE_SSPI
sendAuthRequest(port, AUTH_REQ_SSPI); sendAuthRequest(port, AUTH_REQ_SSPI, NULL, 0);
status = pg_SSPI_recvauth(port); status = pg_SSPI_recvauth(port);
#else #else
Assert(false); Assert(false);
@ -531,12 +532,13 @@ ClientAuthentication(Port *port)
ereport(FATAL, ereport(FATAL,
(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
errmsg("MD5 authentication is not supported when \"db_user_namespace\" is enabled"))); errmsg("MD5 authentication is not supported when \"db_user_namespace\" is enabled")));
sendAuthRequest(port, AUTH_REQ_MD5); /* include the salt to use for computing the response */
sendAuthRequest(port, AUTH_REQ_MD5, port->md5Salt, 4);
status = recv_and_check_password_packet(port, &logdetail); status = recv_and_check_password_packet(port, &logdetail);
break; break;
case uaPassword: case uaPassword:
sendAuthRequest(port, AUTH_REQ_PASSWORD); sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
status = recv_and_check_password_packet(port, &logdetail); status = recv_and_check_password_packet(port, &logdetail);
break; break;
@ -583,7 +585,7 @@ ClientAuthentication(Port *port)
(*ClientAuthentication_hook) (port, status); (*ClientAuthentication_hook) (port, status);
if (status == STATUS_OK) if (status == STATUS_OK)
sendAuthRequest(port, AUTH_REQ_OK); sendAuthRequest(port, AUTH_REQ_OK, NULL, 0);
else else
auth_failed(port, status, logdetail); auth_failed(port, status, logdetail);
} }
@ -593,7 +595,7 @@ ClientAuthentication(Port *port)
* Send an authentication request packet to the frontend. * Send an authentication request packet to the frontend.
*/ */
static void static void
sendAuthRequest(Port *port, AuthRequest areq) sendAuthRequest(Port *port, AuthRequest areq, char *extradata, int extralen)
{ {
StringInfoData buf; StringInfoData buf;
@ -601,28 +603,8 @@ sendAuthRequest(Port *port, AuthRequest areq)
pq_beginmessage(&buf, 'R'); pq_beginmessage(&buf, 'R');
pq_sendint(&buf, (int32) areq, sizeof(int32)); pq_sendint(&buf, (int32) areq, sizeof(int32));
if (extralen > 0)
/* Add the salt for encrypted passwords. */ pq_sendbytes(&buf, extradata, extralen);
if (areq == AUTH_REQ_MD5)
pq_sendbytes(&buf, port->md5Salt, 4);
#if defined(ENABLE_GSS) || defined(ENABLE_SSPI)
/*
* Add the authentication data for the next step of the GSSAPI or SSPI
* negotiation.
*/
else if (areq == AUTH_REQ_GSS_CONT)
{
if (port->gss->outbuf.length > 0)
{
elog(DEBUG4, "sending GSS token of length %u",
(unsigned int) port->gss->outbuf.length);
pq_sendbytes(&buf, port->gss->outbuf.value, port->gss->outbuf.length);
}
}
#endif
pq_endmessage(&buf); pq_endmessage(&buf);
@ -934,7 +916,8 @@ pg_GSS_recvauth(Port *port)
elog(DEBUG4, "sending GSS response token of length %u", elog(DEBUG4, "sending GSS response token of length %u",
(unsigned int) port->gss->outbuf.length); (unsigned int) port->gss->outbuf.length);
sendAuthRequest(port, AUTH_REQ_GSS_CONT); sendAuthRequest(port, AUTH_REQ_GSS_CONT,
port->gss->outbuf.value, port->gss->outbuf.length);
gss_release_buffer(&lmin_s, &port->gss->outbuf); gss_release_buffer(&lmin_s, &port->gss->outbuf);
} }
@ -1179,7 +1162,8 @@ pg_SSPI_recvauth(Port *port)
port->gss->outbuf.length = outbuf.pBuffers[0].cbBuffer; port->gss->outbuf.length = outbuf.pBuffers[0].cbBuffer;
port->gss->outbuf.value = outbuf.pBuffers[0].pvBuffer; port->gss->outbuf.value = outbuf.pBuffers[0].pvBuffer;
sendAuthRequest(port, AUTH_REQ_GSS_CONT); sendAuthRequest(port, AUTH_REQ_GSS_CONT,
port->gss->outbuf.value, port->gss->outbuf.length);
FreeContextBuffer(outbuf.pBuffers[0].pvBuffer); FreeContextBuffer(outbuf.pBuffers[0].pvBuffer);
} }
@ -1807,7 +1791,7 @@ pam_passwd_conv_proc(int num_msg, const struct pam_message ** msg,
* let's go ask the client to send a password, which we * let's go ask the client to send a password, which we
* then stuff into PAM. * then stuff into PAM.
*/ */
sendAuthRequest(pam_port_cludge, AUTH_REQ_PASSWORD); sendAuthRequest(pam_port_cludge, AUTH_REQ_PASSWORD, NULL, 0);
passwd = recv_password_packet(pam_port_cludge); passwd = recv_password_packet(pam_port_cludge);
if (passwd == NULL) if (passwd == NULL)
{ {
@ -2137,7 +2121,7 @@ CheckLDAPAuth(Port *port)
if (port->hba->ldapport == 0) if (port->hba->ldapport == 0)
port->hba->ldapport = LDAP_PORT; port->hba->ldapport = LDAP_PORT;
sendAuthRequest(port, AUTH_REQ_PASSWORD); sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
passwd = recv_password_packet(port); passwd = recv_password_packet(port);
if (passwd == NULL) if (passwd == NULL)
@ -2497,7 +2481,7 @@ CheckRADIUSAuth(Port *port)
identifier = port->hba->radiusidentifier; identifier = port->hba->radiusidentifier;
/* Send regular password request to client, and get the response */ /* Send regular password request to client, and get the response */
sendAuthRequest(port, AUTH_REQ_PASSWORD); sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
passwd = recv_password_packet(port); passwd = recv_password_packet(port);
if (passwd == NULL) if (passwd == NULL)