diff --git a/src/backend/libpq/auth-oauth.c b/src/backend/libpq/auth-oauth.c index 11365048951..6e684e9bd0d 100644 --- a/src/backend/libpq/auth-oauth.c +++ b/src/backend/libpq/auth-oauth.c @@ -684,6 +684,13 @@ validate(Port *port, const char *auth) goto cleanup; } + /* + * Store the validator's expiration callback and timestamp in the Port + * structure to allow for session-wide validity enforcement. + */ + port->expired_cb = ret->expired_cb; + port->expiry = ret->expiry; + if (port->hba->oauth_skip_usermap) { /* diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c index 6570f27297b..9f1be07bf69 100644 --- a/src/backend/libpq/pqcomm.c +++ b/src/backend/libpq/pqcomm.c @@ -319,6 +319,12 @@ pq_init(ClientSocket *client_sock) Assert(socket_pos == FeBeWaitSetSocketPos); Assert(latch_pos == FeBeWaitSetLatchPos); + /* + * Initialize OAuth session fields to safe defaults (no expiry/no callback). + */ + port->expiry = DT_NOBEGIN; + port->expired_cb = NULL; + return port; } diff --git a/src/backend/tcop/postgres.c b/src/backend/tcop/postgres.c index 21de158adbb..291810f4c59 100644 --- a/src/backend/tcop/postgres.c +++ b/src/backend/tcop/postgres.c @@ -185,7 +185,7 @@ static void report_recovery_conflict(RecoveryConflictReason reason); static void log_disconnections(int code, Datum arg); static void enable_statement_timeout(void); static void disable_statement_timeout(void); - +static void check_oauth_expiry(Port *port); /* ---------------------------------------------------------------- * infrastructure for valgrind debugging @@ -1049,6 +1049,13 @@ exec_simple_query(const char *query_string) */ start_xact_command(); + /* + * If the current session was authenticated via OAuth, verify that the + * token has not expired or been revoked before executing the query. + */ + if (MyClientConnectionInfo.auth_method == uaOAuth) + check_oauth_expiry(MyProcPort); + /* * Zap any pre-existing unnamed statement. (While not strictly necessary, * it seems best to define simple-Query mode as if it used the unnamed @@ -1430,6 +1437,13 @@ exec_parse_message(const char *query_string, /* string to execute */ */ start_xact_command(); + /* + * If the current session was authenticated via OAuth, verify that the + * token has not expired or been revoked before executing the query. + */ + if (MyClientConnectionInfo.auth_method == uaOAuth) + check_oauth_expiry(MyProcPort); + /* * Switch to appropriate context for constructing parsetrees. * @@ -1705,6 +1719,13 @@ exec_bind_message(StringInfo input_message) */ start_xact_command(); + /* + * If the current session was authenticated via OAuth, verify that the + * token has not expired or been revoked before executing the query. + */ + if (MyClientConnectionInfo.auth_method == uaOAuth) + check_oauth_expiry(MyProcPort); + /* Switch back to message context */ MemoryContextSwitchTo(MessageContext); @@ -2217,6 +2238,13 @@ exec_execute_message(const char *portal_name, long max_rows) */ start_xact_command(); + /* + * If the current session was authenticated via OAuth, verify that the + * token has not expired or been revoked before executing the query. + */ + if (MyClientConnectionInfo.auth_method == uaOAuth) + check_oauth_expiry(MyProcPort); + /* * If we re-issue an Execute protocol request against an existing portal, * then we are only fetching more rows rather than completely re-executing @@ -2635,6 +2663,13 @@ exec_describe_statement_message(const char *stmt_name) */ start_xact_command(); + /* + * If the current session was authenticated via OAuth, verify that the + * token has not expired or been revoked before executing the query. + */ + if (MyClientConnectionInfo.auth_method == uaOAuth) + check_oauth_expiry(MyProcPort); + /* Switch back to message context */ MemoryContextSwitchTo(MessageContext); @@ -2727,6 +2762,13 @@ exec_describe_portal_message(const char *portal_name) */ start_xact_command(); + /* + * If the current session was authenticated via OAuth, verify that the + * token has not expired or been revoked before executing the query. + */ + if (MyClientConnectionInfo.auth_method == uaOAuth) + check_oauth_expiry(MyProcPort); + /* Switch back to message context */ MemoryContextSwitchTo(MessageContext); @@ -5271,3 +5313,19 @@ disable_statement_timeout(void) if (get_timeout_active(STATEMENT_TIMEOUT)) disable_timeout(STATEMENT_TIMEOUT, false); } + +/* + * Validates the current OAuth session. If a validator has provided a + * callback, execute it. A return value of 'true' triggers a FATAL + * error to terminate the session immediately. + */ +static void +check_oauth_expiry(Port *port) +{ + if (port->expired_cb != NULL && port->expired_cb(port->expiry)) + { + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("session expired: OAuth token is no longer valid"))); + } +} diff --git a/src/include/libpq/libpq-be.h b/src/include/libpq/libpq-be.h index 921b2daa4ff..388e3e8d8ba 100644 --- a/src/include/libpq/libpq-be.h +++ b/src/include/libpq/libpq-be.h @@ -238,6 +238,16 @@ typedef struct Port char *raw_buf; ssize_t raw_buf_consumed, raw_buf_remaining; + + /* + * The expiration time of the authentication credential. + * If not it represents the point in time after which the current session is + * considered invalid. + */ + TimestampTz expiry; + + /* Callback to verify session validity at runtime */ + bool (*expired_cb) (TimestampTz); } Port; /* diff --git a/src/include/libpq/oauth.h b/src/include/libpq/oauth.h index 4a822e9a1f2..c1d278590e1 100644 --- a/src/include/libpq/oauth.h +++ b/src/include/libpq/oauth.h @@ -49,6 +49,22 @@ typedef struct ValidatorModuleResult * delegation. See the validator module documentation for details. */ char *authn_id; + + /* + * Optional callback to check if the session is still valid. + * Returns true if the token is expired/revoked, false otherwise. + * If NULL, the backend assumes the session never expires. + * If provided, the validator can use this to limit session duration based on + * parameter value or based on it's custom logic. + */ + bool (*expired_cb) (TimestampTz expiry); + + /* + * The expiration time of the token (e.g., from the 'exp' claim) if + * provided. This value is passed as an argument to the expired_cb function + * above to determine if the session should terminate. + */ + TimestampTz expiry; } ValidatorModuleResult; /* diff --git a/src/test/modules/oauth_validator/t/003_token_expiry.pl b/src/test/modules/oauth_validator/t/003_token_expiry.pl new file mode 100755 index 00000000000..ec81c248d70 --- /dev/null +++ b/src/test/modules/oauth_validator/t/003_token_expiry.pl @@ -0,0 +1,148 @@ +# +# Test OAuth token expiration implementation +# This test verifies that when an OAuth token expires or the validator callback +# indicates it has been revoked, the session is properly terminated. +# + +use strict; +use warnings; +use JSON::PP qw(encode_json); +use MIME::Base64 qw(encode_base64); +use PostgreSQL::Test::Cluster; +use PostgreSQL::Test::Utils; +use Test::More; +use FindBin; +use lib $FindBin::RealBin; +use OAuth::Server; + +# Skip tests if environment doesn't support them +if (!$ENV{PG_TEST_EXTRA} || $ENV{PG_TEST_EXTRA} !~ /\boauth\b/) +{ + plan skip_all => + 'Potentially unsafe test oauth not enabled in PG_TEST_EXTRA'; +} + +unless (check_pg_config("#define HAVE_SYS_EVENT_H 1") + or check_pg_config("#define HAVE_SYS_EPOLL_H 1")) +{ + plan skip_all => + 'OAuth server-side tests are not supported on this platform'; +} + +if ($ENV{with_libcurl} ne 'yes') +{ + plan skip_all => 'client-side OAuth not supported by this build'; +} + +if ($ENV{with_python} ne 'yes') +{ + plan skip_all => 'OAuth tests require --with-python to run'; +} + +# This test validates that the OAuth token expiration mechanism +# is properly implemented by examining log entries. +# Set environment variables for test execution +# Use the default admin user from the test environment +# This is typically determined by the PostgreSQL::Test::Cluster module + +plan tests => 4; + +# Create a PostgreSQL instance for testing +my $node = PostgreSQL::Test::Cluster->new('oauth_expiry'); +$node->init; +$node->append_conf('postgresql.conf', "log_connections = on"); +$node->append_conf('postgresql.conf', "log_disconnections = on"); +$node->append_conf('postgresql.conf', "oauth_validator_libraries = 'validator'\n"); +$node->start; + +# Create test users +$node->safe_psql('postgres', 'CREATE USER test;'); + +# Start the mock OAuth server +my $webserver = OAuth::Server->new(); +$webserver->run(); + +END +{ + my $exit_code = $?; + $webserver->stop() if defined $webserver; + $? = $exit_code; +} + +my $port = $webserver->port(); +my $issuer = "http://127.0.0.1:$port"; + +# Configure HBA for OAuth authentication +unlink($node->data_dir . '/pg_hba.conf'); +# First, add a specific rule for the test user with OAuth authentication +$node->append_conf( + 'pg_hba.conf', qq{ +# OAuth authentication for test user (this must be the first rule) +local postgres test oauth validator=validator issuer="$issuer" scope="openid postgres" +}); +# Add a separate trust rule for the admin user (after the OAuth rule) +$node->append_conf( + 'pg_hba.conf', qq{ +# Trust authentication for admin access +local all all trust +}); + +$node->reload; + +# Get log start position to track new log entries +my $log_start = $node->wait_for_log(qr/reloading configuration files/); + +# Create a background connection for configuration changes +my $bgconn = $node->background_psql('postgres'); +ok($bgconn, "Background admin connection established"); + +# Enable OAuth token expiration test mode +$bgconn->query_safe("ALTER SYSTEM SET oauth_validator.enable_expiry_test TO true"); +$bgconn->query_safe("ALTER SYSTEM SET oauth_validator.token_expires_in TO 2"); +$node->reload; +$node->wait_for_log(qr/reloading configuration files/, $log_start); + +# Update log position after reload +$log_start = $node->wait_for_log(qr/parameter "oauth_validator.token_expires_in" changed to "2"/); + +# Enable OAuth debug mode for connection testing +# This is required for the test to use OAuth authentication +$ENV{PGOAUTHDEBUG} = "UNSAFE"; + +# Make a connection with OAuth auth +$node->connect_ok( + "user=test dbname=postgres oauth_issuer=$issuer oauth_client_id=f02c6361-0635", + "connect with OAuth token", + # Allow any stderr output since OAuth debugging will produce messages + expected_stderr => qr/.*/ +); + +# Wait for token to expire +note "Waiting for token to expire (3 seconds)..."; +sleep 3; + +# Ensure the OAuth debug environment variable is set before trying the second connection +$ENV{PGOAUTHDEBUG} = "UNSAFE"; + +# Try another OAuth connection +my ($stdout, $stderr) = ('', ''); +$node->psql( + 'postgres', + "SELECT 'This should sleep'; SELECT pg_sleep(3); SELECT 'This should never run';", + extra_params => ['--set', "ON_ERROR_STOP=1"], + env => { PGOAUTHDEBUG => "UNSAFE" }, + timeout => $PostgreSQL::Test::Utils::timeout_default, + connstr => "user=test oauth_issuer=$issuer oauth_client_id=f02c6361-0635", + stdout => \$stdout, + stderr => \$stderr +); + +# Look for token expiration errors in the logs +my $expiry_logged = $node->wait_for_log(qr/session expired: OAuth token is no longer valid/, $log_start); +ok($expiry_logged, "Token expiration message found in server logs"); + +# Clean up +$bgconn->query_safe("ALTER SYSTEM RESET oauth_validator.enable_expiry_test"); +$bgconn->query_safe("ALTER SYSTEM RESET oauth_validator.token_expires_in"); + +$node->stop; diff --git a/src/test/modules/oauth_validator/validator.c b/src/test/modules/oauth_validator/validator.c index 0b983a9dc8f..36468312b82 100644 --- a/src/test/modules/oauth_validator/validator.c +++ b/src/test/modules/oauth_validator/validator.c @@ -18,6 +18,8 @@ #include "miscadmin.h" #include "utils/guc.h" #include "utils/memutils.h" +#include "utils/timestamp.h" +#include PG_MODULE_MAGIC; @@ -40,6 +42,8 @@ static const OAuthValidatorCallbacks validator_callbacks = { /* GUCs */ static char *authn_id = NULL; static bool authorize_tokens = true; +static bool enable_expiry_test = false; +static int token_expires_in = 0; /*--- * Extension entry point. Sets up GUCs for use by tests: @@ -72,6 +76,25 @@ _PG_init(void) 0, NULL, NULL, NULL); + /* Parameters for token expiration testing */ + DefineCustomBoolVariable("oauth_validator.enable_expiry_test", + "Enable token expiration testing", + NULL, + &enable_expiry_test, + false, + PGC_SIGHUP, + 0, + NULL, NULL, NULL); + DefineCustomIntVariable("oauth_validator.token_expires_in", + "Token lifetime in seconds for expiry test", + NULL, + &token_expires_in, + 0, + 0, INT_MAX, + PGC_SIGHUP, + 0, + NULL, NULL, NULL); + MarkGUCPrefixReserved("oauth_validator"); } @@ -114,6 +137,15 @@ validator_shutdown(ValidatorModuleState *state) state->private_data); } +/* + * Test callback function for token expiration checking + */ +static bool +test_token_expired_callback(TimestampTz expiry) +{ + return (TimestampTzPlusMilliseconds(expiry, 0) < GetCurrentTimestamp()); +} + /* * Validator implementation. Logs the incoming data and authorizes the token by * default; the behavior can be modified via the module's GUC settings. @@ -139,5 +171,26 @@ validate_token(const ValidatorModuleState *state, else res->authn_id = pstrdup(role); + /* Set up expiration data if testing is enabled */ + if (enable_expiry_test) + { + TimestampTz now = GetCurrentTimestamp(); + + /* Set the callback if enabled */ + res->expired_cb = test_token_expired_callback; + + /* Add token_expires_in seconds to current time for expiry */ + if (token_expires_in > 0) + { + res->expiry = TimestampTzPlusSeconds(now, token_expires_in); + } + else + { + /* Use a far future time */ + res->expiry = DT_NOEND; + } + + } + return true; }