diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile index 98eb2a8242d..32e4c7280e5 100644 --- a/src/backend/libpq/Makefile +++ b/src/backend/libpq/Makefile @@ -18,6 +18,8 @@ OBJS = \ auth-oauth.o \ auth-sasl.o \ auth-scram.o \ + auth-validate-methods.o \ + auth-validate.o \ auth.o \ be-fsstubs.o \ be-secure-common.o \ diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c index 11365048951..c6a7840cb7a 100644 --- a/src/backend/libpq/auth-oauth.c +++ b/src/backend/libpq/auth-oauth.c @@ -892,3 +892,18 @@ done: return (*err_msg == NULL); } + +/* + * Check if an OAuth token has expired. + * This is called from credential validation to check token validity. + */ +bool +CheckOAuthValidatorExpiration(void) +{ + /* Delegate to validator's expire_cb if available */ + if (ValidatorCallbacks != NULL && ValidatorCallbacks->expire_cb != NULL) + return ValidatorCallbacks->expire_cb(validator_module_state); + + /* No expire_cb, assume valid */ + return true; +} diff --git a/src/backend/libpq/auth-validate-methods.c b/src/backend/libpq/auth-validate-methods.c new file mode 100644 index 00000000000..d7e1506c2a1 --- /dev/null +++ b/src/backend/libpq/auth-validate-methods.c @@ -0,0 +1,140 @@ +/*------------------------------------------------------------------------- + * + * auth-validate-methods.c + * Implementation of authentication credential validation methods + * + * This module provides credential validation methods for various authentication + * types during active PostgreSQL sessions. It includes validation for password + * expiry, OAuth token expiry, and can be extended to other authentication + * mechanisms. + * + * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * IDENTIFICATION + * src/backend/libpq/auth-validate-methods.c + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include "access/htup_details.h" +#include "access/xact.h" +#include "catalog/pg_authid.h" +#include "catalog/catalog.h" +#include "libpq/auth-validate.h" +#include "libpq/libpq-be.h" +#include "libpq/oauth.h" +#include "miscadmin.h" +#include "storage/lmgr.h" +#include "utils/syscache.h" +#include "utils/timestamp.h" + +/* Function declarations for internal use */ +static bool validate_password_credentials(void); +static bool validate_oauth_credentials(void); + +/* Function prototypes */ +void InitializeValidationMethods(void); + +/* + * Initialize validation methods + */ +void +InitializeValidationMethods(void) +{ + /* Register all the validation methods */ + RegisterCredentialValidator(CVT_PASSWORD, validate_password_credentials); + RegisterCredentialValidator(CVT_OAUTH, validate_oauth_credentials); +} + +/* + * Validate password credentials by checking rolvaliduntil + * Returns true if credentials are still valid, false if they have expired. + */ +static bool +validate_password_credentials(void) +{ + HeapTuple tuple = NULL; + Datum rolvaliduntil_datum; + bool validuntil_null; + TimestampTz valid_until = 0; + TimestampTz current_time; + Oid userid; + bool result = false; + + userid = GetSessionUserId(); + + /* + * Try to take AccessShareLock on pg_authid to prevent concurrent modifications + * from interfering with our validation. Use conditional acquisition to avoid + * indefinite waiting during credential validation. + */ + if (!ConditionalLockRelationOid(AuthIdRelationId, AccessShareLock)) + { + /* + * Could not acquire lock immediately, which likely means another session + * is modifying user data. For credential validation, it's better to + * consider credentials valid and retry later than to block indefinitely. + */ + elog(LOG, "credential validation: could not acquire lock on pg_authid immediately, will retry later"); + return true; /* Consider valid */ + } + + PG_TRY(); + { + tuple = SearchSysCache1(AUTHOID, ObjectIdGetDatum(userid)); + + if (HeapTupleIsValid(tuple)) + { + /* Get the expiration time column */ + rolvaliduntil_datum = SysCacheGetAttr(AUTHOID, tuple, + Anum_pg_authid_rolvaliduntil, + &validuntil_null); + if (!validuntil_null) + { + valid_until = DatumGetTimestampTz(rolvaliduntil_datum); + current_time = GetCurrentTimestamp(); + + result = !(valid_until < current_time); + } + else + result = true; + + ReleaseSysCache(tuple); + tuple = NULL; + } + } + PG_CATCH(); + { + if (tuple != NULL) + ReleaseSysCache(tuple); + + UnlockRelationOid(AuthIdRelationId, AccessShareLock); + PG_RE_THROW(); + } + PG_END_TRY(); + + /* Release the relation lock */ + UnlockRelationOid(AuthIdRelationId, AccessShareLock); + + return result; +} + +/* + * Check if an OAuth token has expired. + * + * Returns true if the token is still valid, false if it has expired. + * + * Calls wrapper CheckOAuthValidatorExpiration() function + * to verify that the token hasn't expired. + */ +static bool +validate_oauth_credentials(void) +{ + /* Call the validator's expire_cb to check token expiration */ + if (!CheckOAuthValidatorExpiration()) + return false; + + return true; +} diff --git a/src/backend/libpq/auth-validate.c b/src/backend/libpq/auth-validate.c new file mode 100644 index 00000000000..82c475f6df0 --- /dev/null +++ b/src/backend/libpq/auth-validate.c @@ -0,0 +1,244 @@ +/*------------------------------------------------------------------------- +* +* auth-validate.c +* Implementation of authentication credential validation +* +* This module provides a mechanism for validating credentials during +* an active PostgreSQL session. +* +* Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group +* Portions Copyright (c) 1994, Regents of the University of California +* +* IDENTIFICATION +* src/backend/libpq/auth-validate.c +* +*------------------------------------------------------------------------- +*/ +#include "postgres.h" + +#include "access/xact.h" +#include "access/xlog.h" +#include "libpq/auth.h" +#include "libpq/libpq-be.h" +#include "libpq/auth-validate.h" +#include "libpq/auth-validate-methods.h" +#include "miscadmin.h" +#include "postmaster/postmaster.h" +#include "storage/ipc.h" +#include "tcop/tcopprot.h" +#include "utils/elog.h" +#include "utils/guc.h" +#include "utils/timestamp.h" +#include "utils/timeout.h" + +/* GUC variables */ +bool credential_validation_enabled; +int credential_validation_interval; + +/* Registered credential validators */ +static CredentialValidationCallback validators[CVT_COUNT]; + +/* + * Convert UserAuth enum to CredentialValidationType for validator selection + */ +static CredentialValidationType +UserAuthToValidationType(UserAuth auth_method) +{ + switch (auth_method) + { + case uaPassword: + case uaMD5: + case uaSCRAM: + /* All password-based methods use the password validator */ + return CVT_PASSWORD; + case uaOAuth: + return CVT_OAUTH; + default: + /* No specific validator for other auth methods */ + return CVT_COUNT; /* Invalid value */ + } +} + +/* + * Process credential validation + */ +void +ProcessCredentialValidation(void) +{ + /* Skip validation during initialization, bootstrap, authentication or connection setup */ + if (ClientAuthInProgress || IsInitProcessingMode() || IsBootstrapProcessingMode()) + return; + + /* Check credentials if validation is enabled */ + if (credential_validation_enabled && MyClientConnectionInfo.authn_id != NULL) + { + CredentialValidationStatus status; + UserAuth auth_method = MyClientConnectionInfo.auth_method; + + status = CheckCredentialValidity(); + + switch (status) + { + case CVS_VALID: + /* Credentials are valid, continue */ + break; + + case CVS_EXPIRED: + elog(LOG, "credential validation: credentials expired for auth_method=%d", + (int) auth_method); + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("session credentials have expired"), + errhint("Please reconnect to establish a new authenticated session"))); + break; + + case CVS_ERROR: + elog(LOG, "credential validation: error checking credentials for auth_method=%d", + (int) auth_method); + ereport(WARNING, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("error checking credential validity"), + errhint("Credential validation will be retried at the next interval"))); + break; + } + } +} + +/* + * Initialize credential validation system Called from InitPostgres after + * authentication completes + */ +void +InitializeCredentialValidation(void) +{ + int i; + + /* Define GUC variables */ + DefineCustomBoolVariable("credential_validation.enabled", + "Enable periodic credential validation.", + NULL, + &credential_validation_enabled, + false, + PGC_SUSET, + 0, + NULL, + NULL, + NULL); + + DefineCustomIntVariable("credential_validation.interval", + "Credential validation interval in minutes.", + NULL, + &credential_validation_interval, + 1, /* default: 1 minute */ + 1, /* min: 1 minute */ + 60, /* max: 60 minutes */ + PGC_SUSET, + GUC_UNIT_MIN, + NULL, + NULL, + NULL); + + /* Initialize validator callbacks to NULL */ + for (i = 0; i < CVT_COUNT; i++) + validators[i] = NULL; + + /* Initialize and register all validation methods */ + InitializeValidationMethods(); +} + +/* + * Register a validator callback for a specific authentication method + */ +void +RegisterCredentialValidator(CredentialValidationType method_type, CredentialValidationCallback validator) +{ + if (method_type < 0 || method_type >= CVT_COUNT) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("invalid validation method type: %d", method_type))); + + validators[method_type] = validator; +} + +/* + * Check credential validity using the appropriate validator + */ +CredentialValidationStatus +CheckCredentialValidity(void) +{ + CredentialValidationCallback validator = NULL; + CredentialValidationStatus status; + + /* + * Skip validation for: + * - During shutdown or recovery + * - Non-client backends (any process not serving a client connection) + * - AutoVacuum processes (launcher and workers) + * - Background worker processes + * - Authentication is in progress + */ + if (proc_exit_inprogress || + RecoveryInProgress() || + !IsExternalConnectionBackend(MyBackendType) || + AmAutoVacuumLauncherProcess() || + AmAutoVacuumWorkerProcess() || + AmBackgroundWorkerProcess() || + ClientAuthInProgress) + return CVS_VALID; + /* + * Use the session's authentication method from MyClientConnectionInfo + * to select the appropriate validator. + */ + if (MyClientConnectionInfo.authn_id != NULL) + { + CredentialValidationType validation_type; + + validation_type = UserAuthToValidationType(MyClientConnectionInfo.auth_method); + + /* + * If we have a valid validation type, get the corresponding + * validator + */ + if (validation_type < CVT_COUNT) + validator = validators[validation_type]; + + } + + /* + * If no validator found for the current auth method or no + * authenticated session, skip validation and consider credentials + * valid + */ + if (validator == NULL || !MyClientConnectionInfo.authn_id) + return CVS_VALID; + + /* Call the validator and interpret result */ + PG_TRY(); + { + bool result; + + elog(DEBUG1, "credential validation: calling validator for auth_method=%d", + (int) MyClientConnectionInfo.auth_method); + + result = validator(); + + if (!result) + { + elog(DEBUG1, "credential validation: credentials expired"); + status = CVS_EXPIRED; /* Validator reports credentials expired */ + } + else + status = CVS_VALID; + + return status; + } + PG_CATCH(); + { + /* Error during validation */ + elog(DEBUG1, "credential validation: error during validation"); + + FlushErrorState(); + return CVS_ERROR; + } + PG_END_TRY(); +} diff --git a/src/backend/libpq/meson.build b/src/backend/libpq/meson.build index ee337cf42cc..608d9e10eb0 100644 --- a/src/backend/libpq/meson.build +++ b/src/backend/libpq/meson.build @@ -4,6 +4,8 @@ backend_sources += files( 'auth-oauth.c', 'auth-sasl.c', 'auth-scram.c', + 'auth-validate-methods.c', + 'auth-validate.c', 'auth.c', 'be-fsstubs.c', 'be-secure-common.c', diff --git a/src/backend/tcop/postgres.c b/src/backend/tcop/postgres.c index d01a09dd0c4..7307660e85a 100644 --- a/src/backend/tcop/postgres.c +++ b/src/backend/tcop/postgres.c @@ -44,6 +44,7 @@ #include "libpq/libpq.h" #include "libpq/pqformat.h" #include "libpq/pqsignal.h" +#include "libpq/auth-validate.h" #include "mb/pg_wchar.h" #include "mb/stringinfo_mb.h" #include "miscadmin.h" @@ -97,6 +98,49 @@ bool Log_disconnections = false; int log_statement = LOGSTMT_NONE; + + +/* + * Function that performs credential validation when needed + * Uses a time-based approach to periodically validate credentials + * during normal operation, skipping validation in bootstrapping. + */ +static void +CheckAndExecuteCredentialValidation(void) +{ + TimestampTz now; + TimestampTz diff; + + /* Fast early returns for all cases where we should skip validation */ + if (IsInitProcessingMode() || IsBootstrapProcessingMode()) + return; + + /* Get the current time */ + now = GetCurrentTimestamp(); + + /* Use direct timestamp comparison for better performance */ + if (LastCredentialValidationTime != 0) + { + int64 interval_us; + + diff = now - LastCredentialValidationTime; + interval_us = (int64) credential_validation_interval * 60 * INT64CONST(1000000); /* minutes to microseconds */ + + /* Exit early if not enough time has passed */ + if (diff < interval_us) + return; + } + + /* Process credential validation */ + ProcessCredentialValidation(); + + /* Update the last validation time */ + LastCredentialValidationTime = now; + + /* Only log at DEBUG level to reduce noise */ + elog(DEBUG1, "Credential validation completed successfully"); +} + /* wait N seconds to allow attach from a debugger */ int PostAuthDelay = 0; @@ -1049,6 +1093,10 @@ exec_simple_query(const char *query_string) */ start_xact_command(); + /* Check and potentially execute credential validation using time-based approach */ + if (credential_validation_enabled && credential_validation_interval > 0 && IsNormalProcessingMode()) + CheckAndExecuteCredentialValidation(); + /* * Zap any pre-existing unnamed statement. (While not strictly necessary, * it seems best to define simple-Query mode as if it used the unnamed @@ -1430,6 +1478,10 @@ exec_parse_message(const char *query_string, /* string to execute */ */ start_xact_command(); + /* Check and potentially execute credential validation for extended protocol */ + if (credential_validation_enabled && credential_validation_interval > 0 && IsNormalProcessingMode()) + CheckAndExecuteCredentialValidation(); + /* * Switch to appropriate context for constructing parsetrees. * @@ -1705,6 +1757,10 @@ exec_bind_message(StringInfo input_message) */ start_xact_command(); + /* Check and potentially execute credential validation for extended protocol */ + if (credential_validation_enabled && credential_validation_interval > 0 && IsNormalProcessingMode()) + CheckAndExecuteCredentialValidation(); + /* Switch back to message context */ MemoryContextSwitchTo(MessageContext); @@ -2217,6 +2273,10 @@ exec_execute_message(const char *portal_name, long max_rows) */ start_xact_command(); + /* Check and potentially execute credential validation for extended protocol */ + if (credential_validation_enabled && credential_validation_interval > 0 && IsNormalProcessingMode()) + CheckAndExecuteCredentialValidation(); + /* * If we re-issue an Execute protocol request against an existing portal, * then we are only fetching more rows rather than completely re-executing @@ -2635,6 +2695,10 @@ exec_describe_statement_message(const char *stmt_name) */ start_xact_command(); + /* Check and potentially execute credential validation for extended protocol */ + if (credential_validation_enabled && credential_validation_interval > 0 && IsNormalProcessingMode()) + CheckAndExecuteCredentialValidation(); + /* Switch back to message context */ MemoryContextSwitchTo(MessageContext); @@ -2727,6 +2791,10 @@ exec_describe_portal_message(const char *portal_name) */ start_xact_command(); + /* Check and potentially execute credential validation for extended protocol */ + if (credential_validation_enabled && credential_validation_interval > 0 && IsNormalProcessingMode()) + CheckAndExecuteCredentialValidation(); + /* Switch back to message context */ MemoryContextSwitchTo(MessageContext); diff --git a/src/backend/utils/init/globals.c b/src/backend/utils/init/globals.c index 36ad708b360..45beb71ef22 100644 --- a/src/backend/utils/init/globals.c +++ b/src/backend/utils/init/globals.c @@ -34,6 +34,7 @@ volatile sig_atomic_t QueryCancelPending = false; volatile sig_atomic_t ProcDiePending = false; volatile sig_atomic_t CheckClientConnectionPending = false; volatile sig_atomic_t ClientConnectionLost = false; +TimestampTz LastCredentialValidationTime = 0; volatile sig_atomic_t IdleInTransactionSessionTimeoutPending = false; volatile sig_atomic_t TransactionTimeoutPending = false; volatile sig_atomic_t IdleSessionTimeoutPending = false; diff --git a/src/backend/utils/init/postinit.c b/src/backend/utils/init/postinit.c index b59e08605cc..138fd440600 100644 --- a/src/backend/utils/init/postinit.c +++ b/src/backend/utils/init/postinit.c @@ -34,6 +34,7 @@ #include "catalog/pg_db_role_setting.h" #include "catalog/pg_tablespace.h" #include "libpq/auth.h" +#include "libpq/auth-validate.h" #include "libpq/libpq-be.h" #include "mb/pg_wchar.h" #include "miscadmin.h" @@ -1226,6 +1227,9 @@ InitPostgres(const char *in_dbname, Oid dboid, /* Initialize this backend's session state. */ InitializeSession(); + /* Initialize credential validation system */ + InitializeCredentialValidation(); + /* * If this is an interactive session, load any libraries that should be * preloaded at backend start. Since those are determined by GUCs, this @@ -1440,6 +1444,7 @@ ClientCheckTimeoutHandler(void) SetLatch(MyLatch); } + /* * Returns true if at least one role is defined in this database cluster. */ diff --git a/src/include/libpq/auth-validate-methods.h b/src/include/libpq/auth-validate-methods.h new file mode 100644 index 00000000000..420183a1c7d --- /dev/null +++ b/src/include/libpq/auth-validate-methods.h @@ -0,0 +1,25 @@ +/*------------------------------------------------------------------------- + * + * auth-validate-methods.h + * Interface for authentication credential validation methods + * + * This file provides declarations for various credential validation methods + * used with the credential validation system. + * + * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/include/libpq/auth-validate-methods.h + * + *------------------------------------------------------------------------- + */ +#ifndef AUTH_VALIDATE_METHODS_H +#define AUTH_VALIDATE_METHODS_H + +#include "libpq/libpq-be.h" +#include "utils/timestamp.h" + +/* Initialize all validation methods */ +extern void InitializeValidationMethods(void); + +#endif /* AUTH_VALIDATE_METHODS_H */ diff --git a/src/include/libpq/auth-validate.h b/src/include/libpq/auth-validate.h new file mode 100644 index 00000000000..52b17952744 --- /dev/null +++ b/src/include/libpq/auth-validate.h @@ -0,0 +1,61 @@ +/*------------------------------------------------------------------------- + * + * auth-validate.h + * Interface for authentication credential validation + * + * This file provides a common interface for validating credentials + * during an active PostgreSQL session. + * + * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/include/libpq/auth-validate.h + * + *------------------------------------------------------------------------- + */ +#ifndef AUTH_VALIDATE_H +#define AUTH_VALIDATE_H + +#include "libpq/libpq-be.h" +#include "libpq/protocol.h" +#include "postmaster/postmaster.h" +#include "utils/guc.h" +#include "utils/timeout.h" + +/* Define credential validation method types as an enum */ +typedef enum CredentialValidationType +{ + CVT_PASSWORD = 0, /* All password-based methods (md5, scram, etc) */ + CVT_OAUTH, /* OAuth bearer token authentication */ + CVT_COUNT /* Total number of credential validation types */ +} CredentialValidationType; + +/* Process credential validation */ +extern void ProcessCredentialValidation(void); + +/* GUC variables */ +extern PGDLLIMPORT bool credential_validation_enabled; +extern PGDLLIMPORT int credential_validation_interval; + +/* Common credential validation callback prototype */ +typedef bool (*CredentialValidationCallback) (void); + +/* Credential validation status */ +typedef enum CredentialValidationStatus +{ + CVS_VALID, /* Credentials are valid */ + CVS_EXPIRED, /* Credentials have expired */ + CVS_ERROR /* Error during validation */ +} CredentialValidationStatus; + +/* Initialize credential validation system */ +extern void InitializeCredentialValidation(void); + +/* Register a validation callback for a specific authentication method */ +extern void RegisterCredentialValidator(CredentialValidationType method_type, + CredentialValidationCallback validator); + +/* Check credential validity */ +extern CredentialValidationStatus CheckCredentialValidity(void); + +#endif /* AUTH_VALIDATE_H */ diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h index 4a822e9a1f2..bbb9290626c 100644 --- a/src/include/libpq/oauth.h +++ b/src/include/libpq/oauth.h @@ -64,6 +64,7 @@ typedef void (*ValidatorShutdownCB) (ValidatorModuleState *state); typedef bool (*ValidatorValidateCB) (const ValidatorModuleState *state, const char *token, const char *role, ValidatorModuleResult *result); +typedef bool (*ValidatorExpireCB) (const ValidatorModuleState *state); /* * Identifies the compiled ABI version of the validator module. Since the server @@ -80,6 +81,7 @@ typedef struct OAuthValidatorCallbacks ValidatorStartupCB startup_cb; ValidatorShutdownCB shutdown_cb; ValidatorValidateCB validate_cb; + ValidatorExpireCB expire_cb; /* Optional: Check token expiration */ } OAuthValidatorCallbacks; /* @@ -98,4 +100,8 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_oauth_mech; */ extern bool check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg); +/* + * Check OAuth token expiration using validator's expire_cb if available. + */ +bool CheckOAuthValidatorExpiration(void); #endif /* PG_OAUTH_H */ diff --git a/src/include/miscadmin.h b/src/include/miscadmin.h index f16f35659b9..30c3d40d418 100644 --- a/src/include/miscadmin.h +++ b/src/include/miscadmin.h @@ -99,6 +99,7 @@ extern PGDLLIMPORT volatile sig_atomic_t IdleStatsUpdateTimeoutPending; extern PGDLLIMPORT volatile sig_atomic_t CheckClientConnectionPending; extern PGDLLIMPORT volatile sig_atomic_t ClientConnectionLost; +extern PGDLLIMPORT TimestampTz LastCredentialValidationTime; /* these are marked volatile because they are examined by signal handlers: */ extern PGDLLIMPORT volatile uint32 InterruptHoldoffCount;