From a79257394b3766e676c02543d7f416217f05d293 Mon Sep 17 00:00:00 2001
From: Jacob Champion <jacob.champion@enterprisedb.com>
Date: Fri, 22 Aug 2025 17:39:40 -0700
Subject: [PATCH v3 04/10] WIP: pytest: Add some server-side SSL tests

In the same vein as the previous commit, this is a server-only test
suite operating against a mock client. The test itself is a heavily
parameterized check for direct-SSL handshake behavior, using a
combination of "standard" and "custom" certificates via the certs
fixture.

installcheck is currently unsupported, but the architecture has some
extension points that should make it possible later. For now, a new
server is always started for the test session.

New session-level fixtures have been added which probably need to
migrate to the `pg` package. Of note:

- datadir points to the server's data directory
- sockdir points to the server's UNIX socket/lock directory
- server_instance actually inits and starts a server via the pg_ctl on
  PATH (and could eventually point at an installcheck target)

Wrapping these session-level fixtures is pg_server[_session], which
provides APIs for configuration changes that unwind themselves at the
end of fixture scopes. There's also an example of nested scopes, via
pg_server_session.subcontext(). Many TODOs remain before we're on par
with Test::Cluster, but this should illustrate my desired architecture
pretty well.

Windows currently uses SCRAM-over-UNIX for the admin account rather than
SSPI-over-TCP. There's some dead Win32 code in pg.current_windows_user,
but I've kept it as an illustration of how a developer might write such
code for SSPI. I'll probably remove it in a future patch version.

TODOs:
- port more server configuration behavior from PostgreSQL::Test::Cluster
- decide again on "session" vs. "module" scope for server fixtures
- improve remaining_timeout() integration with socket operations; at the
  moment, the timeout resets on every call rather than decrementing
---
 src/test/pytest/pg/__init__.py  |   1 +
 src/test/pytest/pg/_win32.py    | 145 +++++++++
 src/test/ssl/pyt/conftest.py    | 113 +++++++
 src/test/ssl/pyt/test_server.py | 538 ++++++++++++++++++++++++++++++++
 4 files changed, 797 insertions(+)
 create mode 100644 src/test/pytest/pg/_win32.py
 create mode 100644 src/test/ssl/pyt/test_server.py

diff --git a/src/test/pytest/pg/__init__.py b/src/test/pytest/pg/__init__.py
index ef8faf54ca4..5dae49b6406 100644
--- a/src/test/pytest/pg/__init__.py
+++ b/src/test/pytest/pg/__init__.py
@@ -1,3 +1,4 @@
 # Copyright (c) 2025, PostgreSQL Global Development Group
 
 from ._env import has_test_extra, require_test_extra
+from ._win32 import current_windows_user
diff --git a/src/test/pytest/pg/_win32.py b/src/test/pytest/pg/_win32.py
new file mode 100644
index 00000000000..3fd67b10191
--- /dev/null
+++ b/src/test/pytest/pg/_win32.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+
+import ctypes
+import platform
+
+
+def current_windows_user():
+    """
+    A port of pg_regress.c's current_windows_user() helper. Returns
+    (accountname, domainname).
+
+    XXX This is dead code now, but I'm keeping it as a motivating example of
+    Win32 interaction, and someone may find it useful in the future when writing
+    SSPI tests?
+    """
+    try:
+        advapi32 = ctypes.windll.advapi32
+        kernel32 = ctypes.windll.kernel32
+    except AttributeError:
+        raise RuntimeError(
+            f"current_windows_user() is not supported on {platform.system()}"
+        )
+
+    def raise_winerror_when_false(result, func, arguments):
+        """
+        A ctypes errcheck handler that raises WinError (which will contain the
+        result of GetLastError()) when the function's return value is false.
+        """
+        if not result:
+            raise ctypes.WinError()
+
+    #
+    # Function Prototypes
+    #
+
+    from ctypes import wintypes
+
+    # GetCurrentProcess
+    kernel32.GetCurrentProcess.restype = wintypes.HANDLE
+    kernel32.GetCurrentProcess.argtypes = []
+
+    # OpenProcessToken
+    TOKEN_READ = 0x00020008
+
+    advapi32.OpenProcessToken.restype = wintypes.BOOL
+    advapi32.OpenProcessToken.argtypes = [
+        wintypes.HANDLE,
+        wintypes.DWORD,
+        wintypes.PHANDLE,
+    ]
+    advapi32.OpenProcessToken.errcheck = raise_winerror_when_false
+
+    # GetTokenInformation
+    PSID = wintypes.LPVOID  # we don't need the internals
+    TOKEN_INFORMATION_CLASS = wintypes.INT
+    TokenUser = 1
+
+    class SID_AND_ATTRIBUTES(ctypes.Structure):
+        _fields_ = [
+            ("Sid", PSID),
+            ("Attributes", wintypes.DWORD),
+        ]
+
+    class TOKEN_USER(ctypes.Structure):
+        _fields_ = [
+            ("User", SID_AND_ATTRIBUTES),
+        ]
+
+    advapi32.GetTokenInformation.restype = wintypes.BOOL
+    advapi32.GetTokenInformation.argtypes = [
+        wintypes.HANDLE,
+        TOKEN_INFORMATION_CLASS,
+        wintypes.LPVOID,
+        wintypes.DWORD,
+        wintypes.PDWORD,
+    ]
+    advapi32.GetTokenInformation.errcheck = raise_winerror_when_false
+
+    # LookupAccountSid
+    SID_NAME_USE = wintypes.INT
+    PSID_NAME_USE = ctypes.POINTER(SID_NAME_USE)
+
+    advapi32.LookupAccountSidW.restype = wintypes.BOOL
+    advapi32.LookupAccountSidW.argtypes = [
+        wintypes.LPCWSTR,
+        PSID,
+        wintypes.LPWSTR,
+        wintypes.LPDWORD,
+        wintypes.LPWSTR,
+        wintypes.LPDWORD,
+        PSID_NAME_USE,
+    ]
+    advapi32.LookupAccountSidW.errcheck = raise_winerror_when_false
+
+    #
+    # Implementation (see pg_SSPI_recv_auth())
+    #
+
+    # Get the current process token...
+    token = wintypes.HANDLE()
+    proc = kernel32.GetCurrentProcess()
+    advapi32.OpenProcessToken(proc, TOKEN_READ, token)
+
+    # ...then read the TOKEN_USER struct for that token...
+    info = TOKEN_USER()
+    infolen = wintypes.DWORD()
+
+    try:
+        # (GetTokenInformation creates a buffer bigger than TOKEN_USER, so we
+        # have to query the correct length first.)
+        advapi32.GetTokenInformation(token, TokenUser, None, 0, ctypes.byref(infolen))
+        assert False, "GetTokenInformation succeeded unexpectedly"
+
+    except OSError as err:
+        assert err.winerror == 122  # insufficient buffer
+
+        ctypes.resize(info, infolen.value)
+        advapi32.GetTokenInformation(
+            token,
+            TokenUser,
+            ctypes.byref(info),
+            ctypes.sizeof(info),
+            ctypes.byref(infolen),
+        )
+
+    # ...then pull the account and domain names out of the user SID.
+    MAXPGPATH = 1024
+
+    account = ctypes.create_unicode_buffer(MAXPGPATH)
+    domain = ctypes.create_unicode_buffer(MAXPGPATH)
+    accountlen = wintypes.DWORD(ctypes.sizeof(account))
+    domainlen = wintypes.DWORD(ctypes.sizeof(domain))
+    use = SID_NAME_USE()
+
+    advapi32.LookupAccountSidW(
+        None,
+        info.User.Sid,
+        account,
+        ctypes.byref(accountlen),
+        domain,
+        ctypes.byref(domainlen),
+        ctypes.byref(use),
+    )
+
+    return (account.value, domain.value)
diff --git a/src/test/ssl/pyt/conftest.py b/src/test/ssl/pyt/conftest.py
index fb4db372f03..85d2c994828 100644
--- a/src/test/ssl/pyt/conftest.py
+++ b/src/test/ssl/pyt/conftest.py
@@ -1,6 +1,12 @@
 # Copyright (c) 2025, PostgreSQL Global Development Group
 
 import datetime
+import os
+import pathlib
+import platform
+import secrets
+import socket
+import subprocess
 import tempfile
 from collections import namedtuple
 
@@ -127,3 +133,110 @@ def certs(cryptography, tmp_path_factory):
             return f.name
 
     return _Certs()
+
+
+@pytest.fixture(scope="session")
+def datadir(tmp_path_factory):
+    """
+    Returns the directory name to use as the server data directory. If
+    TESTDATADIR is provided, that will be used; otherwise a new temporary
+    directory is created in the pytest temp root.
+    """
+    d = os.getenv("TESTDATADIR")
+    if d:
+        d = pathlib.Path(d)
+    else:
+        d = tmp_path_factory.mktemp("tmp_check")
+
+    return d
+
+
+@pytest.fixture(scope="session")
+def sockdir(tmp_path_factory):
+    """
+    Returns the directory name to use as the server's unix_socket_directories
+    setting. Local client connections use this as the PGHOST.
+
+    At the moment, this is always put under the pytest temp root.
+    """
+    return tmp_path_factory.mktemp("sockfiles")
+
+
+@pytest.fixture(scope="session")
+def winpassword():
+    """The per-session SCRAM password for the server admin on Windows."""
+    return secrets.token_urlsafe(16)
+
+
+@pytest.fixture(scope="session")
+def server_instance(certs, datadir, sockdir, winpassword):
+    """
+    Starts a running Postgres server listening on localhost. The HBA initially
+    allows only local UNIX connections from the same user.
+
+    TODO: when installcheck is supported, this should optionally point to the
+    currently running server instead.
+    """
+
+    # Lock down the HBA by default; tests can open it back up later.
+    if platform.system() == "Windows":
+        # On Windows, for admin connections, use SCRAM with a generated password
+        # over local sockets. This requires additional work during initdb.
+        method = "scram-sha-256"
+
+        # NamedTemporaryFile doesn't work very nicely on Windows until Python
+        # 3.12, which introduces NamedTemporaryFile(delete_on_close=False).
+        # Until then, specify delete=False and manually unlink after use.
+        with tempfile.NamedTemporaryFile("w", delete=False) as pwfile:
+            pwfile.write(winpassword)
+
+        subprocess.check_call(
+            ["initdb", "--auth=scram-sha-256", "--pwfile", pwfile.name, datadir]
+        )
+        os.unlink(pwfile.name)
+
+    else:
+        # For other OSes we can just use peer auth.
+        method = "peer"
+        subprocess.check_call(["pg_ctl", "-D", datadir, "init"])
+
+    with open(datadir / "pg_hba.conf", "w") as f:
+        print(f"# default: local {method} connections only", file=f)
+        print(f"local all all {method}", file=f)
+
+    # Figure out a port to listen on. Attempt to reserve both IPv4 and IPv6
+    # addresses in one go.
+    #
+    # Note: socket.has_dualstack_ipv6/create_server are only in Python 3.8+.
+    if hasattr(socket, "has_dualstack_ipv6") and socket.has_dualstack_ipv6():
+        addr = ("::1", 0)
+        s = socket.create_server(addr, family=socket.AF_INET6, dualstack_ipv6=True)
+
+        hostaddr, port, _, _ = s.getsockname()
+        addrs = [hostaddr, "127.0.0.1"]
+
+    else:
+        addr = ("127.0.0.1", 0)
+
+        s = socket.socket()
+        s.bind(addr)
+
+        hostaddr, port = s.getsockname()
+        addrs = [hostaddr]
+
+    log = os.path.join(datadir, "postgresql.log")
+
+    with s, open(os.path.join(datadir, "postgresql.conf"), "a") as f:
+        print(file=f)
+        print("unix_socket_directories = '{}'".format(sockdir.as_posix()), file=f)
+        print("listen_addresses = '{}'".format(",".join(addrs)), file=f)
+        print("port =", port, file=f)
+        print("log_connections = all", file=f)
+
+    # Between closing of the socket, s, and server start, we're racing against
+    # anything that wants to open up ephemeral ports, so try not to put any new
+    # work here.
+
+    subprocess.check_call(["pg_ctl", "-D", datadir, "-l", log, "start"])
+    yield (hostaddr, port)
+    subprocess.check_call(["pg_ctl", "-D", datadir, "-l", log, "stop"])
diff --git a/src/test/ssl/pyt/test_server.py b/src/test/ssl/pyt/test_server.py
new file mode 100644
index 00000000000..2d0be735371
--- /dev/null
+++ b/src/test/ssl/pyt/test_server.py
@@ -0,0 +1,538 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+
+import contextlib
+import os
+import pathlib
+import platform
+import re
+import shutil
+import socket
+import ssl
+import struct
+import subprocess
+import tempfile
+from collections import namedtuple
+from typing import Dict, List, Union
+
+import pytest
+
+import pg
+
+# This suite opens up local TCP ports and is hidden behind PG_TEST_EXTRA=ssl.
+pytestmark = pg.require_test_extra("ssl")
+
+
+#
+# Test Fixtures
+#
+
+
+@pytest.fixture(scope="session")
+def connenv(server_instance, sockdir, datadir):
+    """
+    Provides the values for several PG* environment variables needed for our
+    utility programs to connect to the server_instance.
+    """
+    return {
+        "PGHOST": str(sockdir),
+        "PGPORT": str(server_instance[1]),
+        "PGDATABASE": "postgres",
+        "PGDATA": str(datadir),
+    }
+
+
+class FileBackup(contextlib.AbstractContextManager):
+    """
+    A context manager which backs up a file's contents, restoring them on exit.
+    """
+
+    def __init__(self, file: pathlib.Path):
+        super().__init__()
+
+        self._file = file
+
+    def __enter__(self):
+        with tempfile.NamedTemporaryFile(
+            prefix=self._file.name, dir=self._file.parent, delete=False
+        ) as f:
+            self._backup = pathlib.Path(f.name)
+
+        shutil.copyfile(self._file, self._backup)
+
+        return self
+
+    def __exit__(self, *exc):
+        # Swap the backup and the original file, so that the modified contents
+        # can still be inspected in case of failure.
+        #
+        # TODO: this is less helpful if there are multiple layers, because it's
+        # not clear which backup to look at. Can the backup name be printed as
+        # part of the failed test output? Should we only swap on test failure?
+        tmp = self._backup.parent / (self._backup.name + ".tmp")
+
+        shutil.copyfile(self._file, tmp)
+        shutil.copyfile(self._backup, self._file)
+        shutil.move(tmp, self._backup)
+
+
+class HBA(FileBackup):
+    """
+    Backs up a server's HBA configuration and provides means for temporarily
+    editing it. See also pg_server, which provides an instance of this class and
+    context managers for enforcing the reload/restart order of operations.
+    """
+
+    def __init__(self, datadir: pathlib.Path):
+        super().__init__(datadir / "pg_hba.conf")
+
+    def prepend(self, *lines: Union[str, List[str]]):
+        """
+        Temporarily prepends lines to the server's pg_hba.conf.
+
+        As sugar for aligning HBA columns in the tests, each line can be either
+        a string or a list of strings. List elements will be joined by single
+        spaces before they are written to file.
+        """
+        with open(self._file, "r") as f:
+            prior_data = f.read()
+
+        with open(self._file, "w") as f:
+            for l in lines:
+                if isinstance(l, list):
+                    print(*l, file=f)
+                else:
+                    print(l, file=f)
+
+            f.write(prior_data)
+
+
+class Config(FileBackup):
+    """
+    Backs up a server's postgresql.conf and provides means for temporarily
+    editing it. See also pg_server, which provides an instance of this class and
+    context managers for enforcing the reload/restart order of operations.
+    """
+
+    def __init__(self, datadir: pathlib.Path):
+        super().__init__(datadir / "postgresql.conf")
+
+    def set(self, **gucs):
+        """
+        Temporarily appends GUC settings to the server's postgresql.conf.
+        """
+
+        with open(self._file, "a") as f:
+            print(file=f)
+
+            for n, v in gucs.items():
+                v = str(v)
+
+                # TODO: proper quoting
+                v = v.replace("\\", "\\\\")
+                v = v.replace("'", "\\'")
+                v = "'{}'".format(v)
+
+                print(n, "=", v, file=f)
+
+
+@pytest.fixture(scope="session")
+def pg_server_session(server_instance, connenv, datadir, winpassword):
+    """
+    Provides common routines for configuring and connecting to the
+    server_instance. For example:
+
+        users = pg_server_session.create_users("one", "two")
+        dbs = pg_server_session.create_dbs("default")
+
+        with pg_server_session.reloading() as s:
+            s.hba.prepend(["local", dbs["default"], users["two"], "peer"])
+
+        conn = connect_somehow(**pg_server_session.conninfo)
+        ...
+
+    Attributes of note are
+    - .conninfo: provides TCP connection info for the server
+
+    This fixture unwinds its configuration changes at the end of the pytest
+    session. For more granular changes, pg_server_session.subcontext() splits
+    off a "nested" context to allow smaller scopes.
+    """
+
+    class _Server(contextlib.ExitStack):
+        conninfo = dict(
+            hostaddr=server_instance[0],
+            port=server_instance[1],
+        )
+
+        # for _backup_configuration()
+        _Backup = namedtuple("Backup", "conf, hba")
+
+        def subcontext(self):
+            """
+            Creates a new server stack instance that can be tied to a smaller
+            scope than "session".
+            """
+            # So far, there doesn't seem to be a need to link the two objects,
+            # since HBA/Config/FileBackup operate directly on the filesystem and
+            # will appear to "nest" naturally.
+            return self.__class__()
+
+        def create_users(self, *userkeys: str) -> Dict[str, str]:
+            """
+            Creates new users which will be dropped at the end of the server
+            context.
+
+            For each provided key, a related user name will be selected and
+            stored in a map. This map is returned to let calling code look up
+            the selected usernames (instead of hardcoding them and potentially
+            stomping on an existing installation).
+            """
+            usermap = {}
+
+            for u in userkeys:
+                # TODO: use a uniquifier to support installcheck
+                name = u + "user"
+                usermap[u] = name
+
+                # TODO: proper escaping
+                self.psql("-c", "CREATE USER " + name)
+                self.callback(self.psql, "-c", "DROP USER " + name)
+
+            return usermap
+
+        def create_dbs(self, *dbkeys: str) -> Dict[str, str]:
+            """
+            Creates new databases which will be dropped at the end of the server
+            context. See create_users() for the meaning of the keys and returned
+            map.
+            """
+            dbmap = {}
+
+            for d in dbkeys:
+                # TODO: use a uniquifier to support installcheck
+                name = d + "db"
+                dbmap[d] = name
+
+                # TODO: proper escaping
+                self.psql("-c", "CREATE DATABASE " + name)
+                self.callback(self.psql, "-c", "DROP DATABASE " + name)
+
+            return dbmap
+
+        @contextlib.contextmanager
+        def reloading(self):
+            """
+            Provides a context manager for making configuration changes.
+
+            If the context suite finishes successfully, the configuration will
+            be reloaded via pg_ctl. On teardown, the configuration changes will
+            be unwound, and the server will be signaled to reload again.
+
+            The context target contains the following attributes which can be
+            used to configure the server:
+            - .conf: modifies postgresql.conf
+            - .hba: modifies pg_hba.conf
+
+            For example:
+
+                with pg_server_session.reloading() as s:
+                    s.conf.set(log_connections="on")
+                    s.hba.prepend("local all all trust")
+            """
+            try:
+                # Push a reload onto the stack before making any other
+                # unwindable changes. That way the order of operations will be
+                #
+                #  # test
+                #   - config change 1
+                #   - config change 2
+                #   - reload
+                #  # teardown
+                #   - undo config change 2
+                #   - undo config change 1
+                #   - reload
+                #
+                self.callback(self.pg_ctl, "reload")
+                yield self._backup_configuration()
+            except:
+                # We only want to reload at the end of the suite if there were
+                # no errors. During exceptions, the pushed callback handles
+                # things instead, so there's nothing to do here.
+                raise
+            else:
+                # Suite completed successfully.
+                self.pg_ctl("reload")
+
+        @contextlib.contextmanager
+        def restarting(self):
+            """Like .reloading(), but with a full server restart."""
+            try:
+                self.callback(self.pg_ctl, "restart")
+                yield self._backup_configuration()
+            except:
+                raise
+            else:
+                self.pg_ctl("restart")
+
+        def psql(self, *args):
+            """
+            Runs psql with the given arguments. Password prompts are always
+            disabled. On Windows, the admin password will be included in the
+            environment.
+            """
+            if platform.system() == "Windows":
+                pw = dict(PGPASSWORD=winpassword)
+            else:
+                pw = None
+
+            self._run("psql", "-w", *args, addenv=pw)
+
+        def pg_ctl(self, *args):
+            """
+            Runs pg_ctl with the given arguments. Log output will be placed in
+            postgresql.log in the server's data directory.
+
+            TODO: put the log in TESTLOGDIR
+            """
+            self._run("pg_ctl", "-l", str(datadir / "postgresql.log"), *args)
+
+        def _run(self, cmd, *args, addenv: dict = None):
+            # Override the existing environment with the connenv values and
+            # anything the caller wanted to add. (Python 3.9 gives us the
+            # less-ugly `os.environ | connenv` merge operator.)
+            subenv = dict(os.environ, **connenv)
+            if addenv:
+                subenv.update(addenv)
+
+            subprocess.check_call([cmd, *args], env=subenv)
+
+        def _backup_configuration(self):
+            # Wrap the existing HBA and configuration with FileBackups.
+            return self._Backup(
+                hba=self.enter_context(HBA(datadir)),
+                conf=self.enter_context(Config(datadir)),
+            )
+
+    with _Server() as s:
+        yield s
+
+
+@pytest.fixture(scope="module", autouse=True)
+def ssl_setup(pg_server_session, certs, datadir):
+    """
+    Sets up required server settings for all tests in this module. The fixture
+    variable is a tuple (users, dbs) containing the user and database names that
+    have been chosen for the test session.
+    """
+    try:
+        with pg_server_session.restarting() as s:
+            s.conf.set(
+                ssl="on",
+                ssl_ca_file=certs.ca.certpath,
+                ssl_cert_file=certs.server.certpath,
+                ssl_key_file=certs.server.keypath,
+            )
+
+            # Reject by default.
+            s.hba.prepend("hostssl all all all reject")
+
+    except subprocess.CalledProcessError:
+        # This is a decent place to skip if the server isn't set up for SSL.
+        logpath = datadir / "postgresql.log"
+        unsupported = re.compile("SSL is not supported")
+
+        with open(logpath, "r") as log:
+            for line in log:
+                if unsupported.search(line):
+                    pytest.skip("the server does not support SSL")
+
+        # Some other error happened.
+        raise
+
+    users = pg_server_session.create_users(
+        "ssl",
+    )
+
+    dbs = pg_server_session.create_dbs(
+        "ssl",
+    )
+
+    return (users, dbs)
+
+
+@pytest.fixture(scope="module")
+def client_cert(ssl_setup, certs):
+    """
+    Creates a Cert for the "ssl" user.
+    """
+    from cryptography import x509
+    from cryptography.x509.oid import NameOID
+
+    users, _ = ssl_setup
+    user = users["ssl"]
+
+    return certs.new(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, user)]))
+
+
+@pytest.fixture
+def pg_server(pg_server_session):
+    """
+    A per-test instance of pg_server_session. Use this fixture to make changes
+    to the server which will be rolled back at the end of every test.
+    """
+    with pg_server_session.subcontext() as s:
+        yield s
+
+
+#
+# Tests
+#
+
+
+# For use with the `creds` parameter below.
+CLIENT = "client"
+SERVER = "server"
+
+
+@pytest.mark.parametrize(
+    # fmt: off
+    "auth_method,                    creds,  expected_error",
+[
+    # Trust allows anything.
+    ("trust",                        None,   None),
+    ("trust",                        CLIENT, None),
+    ("trust",                        SERVER, None),
+
+    # verify-ca allows any CA-signed certificate.
+    ("trust clientcert=verify-ca",   None,   "requires a valid client certificate"),
+    ("trust clientcert=verify-ca",   CLIENT, None),
+    ("trust clientcert=verify-ca",   SERVER, None),
+
+    # cert and verify-full allow only the correct certificate.
+    ("trust clientcert=verify-full", None,   "requires a valid client certificate"),
+    ("trust clientcert=verify-full", CLIENT, None),
+    ("trust clientcert=verify-full", SERVER, "authentication failed for user"),
+    ("cert",                         None,   "requires a valid client certificate"),
+    ("cert",                         CLIENT, None),
+    ("cert",                         SERVER, "authentication failed for user"),
+],
+    # fmt: on
+)
+def test_direct_ssl_certificate_authentication(
+    pg_server,
+    ssl_setup,
+    certs,
+    client_cert,
+    remaining_timeout,
+    # test parameters
+    auth_method,
+    creds,
+    expected_error,
+):
+    """
+    Tests direct SSL connections with various client-certificate/HBA
+    combinations.
+    """
+
+    # Set up the HBA as desired by the test.
+    users, dbs = ssl_setup
+
+    user = users["ssl"]
+    db = dbs["ssl"]
+
+    with pg_server.reloading() as s:
+        s.hba.prepend(
+            ["hostssl", db, user, "127.0.0.1/32", auth_method],
+            ["hostssl", db, user, "::1/128", auth_method],
+        )
+
+    # Configure the SSL settings for the client.
+    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    ctx.load_verify_locations(cafile=certs.ca.certpath)
+    ctx.set_alpn_protocols(["postgresql"])  # for direct SSL
+
+    # Load up a client certificate if required by the test.
+    if creds == CLIENT:
+        ctx.load_cert_chain(client_cert.certpath, client_cert.keypath)
+    elif creds == SERVER:
+        # Using a server certificate as the client credential is expected to
+        # work only for clientcert=verify-ca (and `trust`, naturally).
+        ctx.load_cert_chain(certs.server.certpath, certs.server.keypath)
+
+    # Make a direct SSL connection. There's no SSLRequest in the handshake; we
+    # simply wrap a TCP connection with OpenSSL.
+    addr = (pg_server.conninfo["hostaddr"], pg_server.conninfo["port"])
+    with socket.create_connection(addr) as s:
+        s.settimeout(remaining_timeout())  # XXX this resets every operation
+
+        with ctx.wrap_socket(s, server_hostname=certs.server_host) as conn:
+            # Build and send the startup packet.
+            startup_options = dict(
+                user=user,
+                database=db,
+                application_name="pytest",
+            )
+
+            payload = b""
+            for k, v in startup_options.items():
+                payload += k.encode() + b"\0"
+                payload += str(v).encode() + b"\0"
+            payload += b"\0"  # null terminator
+
+            pktlen = 4 + 4 + len(payload)
+            conn.send(struct.pack("!IHH", pktlen, 3, 0) + payload)
+
+            if not expected_error:
+                # Expect an AuthenticationOK to come back.
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"R"
+                assert pktlen == 8
+
+                authn_result = struct.unpack("!I", conn.recv(4))[0]
+                assert authn_result == 0
+
+                # Read and discard to ReadyForQuery.
+                while True:
+                    pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                    payload = conn.recv(pktlen - 4)
+
+                    if pkttype == b"Z":
+                        assert payload == b"I"
+                        break
+
+                # Send an empty query.
+                conn.send(struct.pack("!cI", b"Q", 5) + b"\0")
+
+                # Expect EmptyQueryResponse+ReadyForQuery.
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"I"
+                assert pktlen == 4
+
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"Z"
+
+                payload = conn.recv(pktlen - 4)
+                assert payload == b"I"
+
+            else:
+                # Match the expected authentication error.
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"E"
+
+                payload = conn.recv(pktlen - 4)
+                msg = None
+
+                for component in payload.split(b"\0"):
+                    if not component:
+                        break  # end of message
+
+                    key, val = component[:1], component[1:]
+                    if key == b"S":
+                        assert val == b"FATAL"
+                    elif key == b"M":
+                        msg = val.decode()
+
+                assert re.search(expected_error, msg), "server error did not match"
+
+            # Terminate.
+            conn.send(struct.pack("!cI", b"X", 4))
-- 
2.51.1

