From c52891010df78694377e5d9d6cc6c115645b3ad1 Mon Sep 17 00:00:00 2001
From: "okbob@github.com" <okbob@github.com>
Date: Sat, 20 Jan 2024 07:57:29 +0100
Subject: [PATCH 15/20] allow parallel execution queries with session variables

---
 src/backend/executor/execMain.c               |  14 +-
 src/backend/executor/execParallel.c           | 147 +++++++++++++++++-
 src/backend/optimizer/util/clauses.c          |  18 +--
 .../regress/expected/session_variables.out    |  12 +-
 4 files changed, 171 insertions(+), 20 deletions(-)

diff --git a/src/backend/executor/execMain.c b/src/backend/executor/execMain.c
index f2fb680defc..4fbe7c71840 100644
--- a/src/backend/executor/execMain.c
+++ b/src/backend/executor/execMain.c
@@ -203,7 +203,19 @@ standard_ExecutorStart(QueryDesc *queryDesc, int eflags)
 	 * related session variables are copied to dedicated array, and this array
 	 * is passed to executor.
 	 */
-	if (queryDesc->plannedstmt->sessionVariables)
+	if (queryDesc->num_session_variables > 0)
+	{
+		/*
+		 * When a parallel query needs to access query parameters (including
+		 * related session variables), then related session variables are
+		 * restored (deserialized) in queryDesc already. So just push pointer
+		 * of this array to executor's estate.
+		 */
+		Assert(IsParallelWorker());
+		estate->es_session_variables = queryDesc->session_variables;
+		estate->es_num_session_variables = queryDesc->num_session_variables;
+	}
+	else if (queryDesc->plannedstmt->sessionVariables)
 	{
 		ListCell   *lc;
 		int			nSessionVariables;
diff --git a/src/backend/executor/execParallel.c b/src/backend/executor/execParallel.c
index bfb3419efb7..cb434e2768e 100644
--- a/src/backend/executor/execParallel.c
+++ b/src/backend/executor/execParallel.c
@@ -12,8 +12,9 @@
  * workers and ensuring that their state generally matches that of the
  * leader; see src/backend/access/transam/README.parallel for details.
  * However, we must save and restore relevant executor state, such as
- * any ParamListInfo associated with the query, buffer/WAL usage info, and
- * the actual plan to be passed down to the worker.
+ * any ParamListInfo associated with the query, buffer/WAL usage info,
+ * session variables buffer, and the actual plan to be passed down to
+ * the worker.
  *
  * IDENTIFICATION
  *	  src/backend/executor/execParallel.c
@@ -64,6 +65,7 @@
 #define PARALLEL_KEY_QUERY_TEXT		UINT64CONST(0xE000000000000008)
 #define PARALLEL_KEY_JIT_INSTRUMENTATION UINT64CONST(0xE000000000000009)
 #define PARALLEL_KEY_WAL_USAGE			UINT64CONST(0xE00000000000000A)
+#define PARALLEL_KEY_SESSION_VARIABLES	UINT64CONST(0xE00000000000000B)
 
 #define PARALLEL_TUPLE_QUEUE_SIZE		65536
 
@@ -138,6 +140,12 @@ static bool ExecParallelRetrieveInstrumentation(PlanState *planstate,
 /* Helper function that runs in the parallel worker. */
 static DestReceiver *ExecParallelGetReceiver(dsm_segment *seg, shm_toc *toc);
 
+/* Helper functions that can pass values of session variables */
+static Size EstimateSessionVariables(EState *estate);
+static void SerializeSessionVariables(EState *estate, char **start_address);
+static SessionVariableValue *RestoreSessionVariables(char **start_address,
+													 int *num_session_variables);
+
 /*
  * Create a serialized representation of the plan to be sent to each worker.
  */
@@ -596,6 +604,7 @@ ExecInitParallelPlan(PlanState *planstate, EState *estate,
 	char	   *pstmt_data;
 	char	   *pstmt_space;
 	char	   *paramlistinfo_space;
+	char	   *session_variables_space;
 	BufferUsage *bufusage_space;
 	WalUsage   *walusage_space;
 	SharedExecutorInstrumentation *instrumentation = NULL;
@@ -605,6 +614,7 @@ ExecInitParallelPlan(PlanState *planstate, EState *estate,
 	int			instrumentation_len = 0;
 	int			jit_instrumentation_len = 0;
 	int			instrument_offset = 0;
+	int			session_variables_len = 0;
 	Size		dsa_minsize = dsa_minimum_size();
 	char	   *query_string;
 	int			query_len;
@@ -660,6 +670,11 @@ ExecInitParallelPlan(PlanState *planstate, EState *estate,
 	shm_toc_estimate_chunk(&pcxt->estimator, paramlistinfo_len);
 	shm_toc_estimate_keys(&pcxt->estimator, 1);
 
+	/* Estimate space for serialized session variables. */
+	session_variables_len = EstimateSessionVariables(estate);
+	shm_toc_estimate_chunk(&pcxt->estimator, session_variables_len);
+	shm_toc_estimate_keys(&pcxt->estimator, 1);
+
 	/*
 	 * Estimate space for BufferUsage.
 	 *
@@ -761,6 +776,11 @@ ExecInitParallelPlan(PlanState *planstate, EState *estate,
 	shm_toc_insert(pcxt->toc, PARALLEL_KEY_PARAMLISTINFO, paramlistinfo_space);
 	SerializeParamList(estate->es_param_list_info, &paramlistinfo_space);
 
+	/* Store serialized session variables. */
+	session_variables_space = shm_toc_allocate(pcxt->toc, session_variables_len);
+	shm_toc_insert(pcxt->toc, PARALLEL_KEY_SESSION_VARIABLES, session_variables_space);
+	SerializeSessionVariables(estate, &session_variables_space);
+
 	/* Allocate space for each worker's BufferUsage; no need to initialize. */
 	bufusage_space = shm_toc_allocate(pcxt->toc,
 									  mul_size(sizeof(BufferUsage), pcxt->nworkers));
@@ -1411,6 +1431,7 @@ ParallelQueryMain(dsm_segment *seg, shm_toc *toc)
 	SharedJitInstrumentation *jit_instrumentation;
 	int			instrument_options = 0;
 	void	   *area_space;
+	char	   *sessionvariable_space;
 	dsa_area   *area;
 	ParallelWorkerContext pwcxt;
 
@@ -1436,6 +1457,14 @@ ParallelQueryMain(dsm_segment *seg, shm_toc *toc)
 	area_space = shm_toc_lookup(toc, PARALLEL_KEY_DSA, false);
 	area = dsa_attach_in_place(area_space, seg);
 
+	/* Reconstruct session variables. */
+	sessionvariable_space = shm_toc_lookup(toc,
+										   PARALLEL_KEY_SESSION_VARIABLES,
+										   false);
+	queryDesc->session_variables =
+		RestoreSessionVariables(&sessionvariable_space,
+								&queryDesc->num_session_variables);
+
 	/* Start up the executor */
 	queryDesc->plannedstmt->jitFlags = fpes->jit_flags;
 	ExecutorStart(queryDesc, fpes->eflags);
@@ -1504,3 +1533,117 @@ ParallelQueryMain(dsm_segment *seg, shm_toc *toc)
 	FreeQueryDesc(queryDesc);
 	receiver->rDestroy(receiver);
 }
+
+/*
+ * Estimate the amount of space required to serialize a session variable.
+ */
+static Size
+EstimateSessionVariables(EState *estate)
+{
+	int			i;
+	Size		sz = sizeof(int);
+
+	if (estate->es_session_variables == NULL)
+		return sz;
+
+	for (i = 0; i < estate->es_num_session_variables; i++)
+	{
+		SessionVariableValue *svarval;
+		Oid			typeOid;
+		int16		typLen;
+		bool		typByVal;
+
+		svarval = &estate->es_session_variables[i];
+
+		typeOid = svarval->typid;
+
+		sz = add_size(sz, sizeof(Oid)); /* space for type OID */
+
+		/* space for datum/isnull */
+		Assert(OidIsValid(typeOid));
+		get_typlenbyval(typeOid, &typLen, &typByVal);
+
+		sz = add_size(sz,
+					  datumEstimateSpace(svarval->value, svarval->isnull, typByVal, typLen));
+	}
+
+	return sz;
+}
+
+/*
+ * Serialize a session variables buffer into caller-provided storage.
+ *
+ * We write the number of parameters first, as a 4-byte integer, and then
+ * write details for each parameter in turn.  The details for each parameter
+ * consist of a 4-byte type OID, and then the datum as serialized by
+ * datumSerialize().  The caller is responsible for ensuring that there is
+ * enough storage to store the number of bytes that will be written; use
+ * EstimateSessionVariables to find out how many will be needed.
+ * *start_address is updated to point to the byte immediately following those
+ * written.
+ *
+ * RestoreSessionVariables can be used to recreate a session variable buffer
+ * based on the serialized representation;
+ */
+static void
+SerializeSessionVariables(EState *estate, char **start_address)
+{
+	int			nparams;
+	int			i;
+
+	/* Write number of parameters. */
+	nparams = estate->es_num_session_variables;
+	memcpy(*start_address, &nparams, sizeof(int));
+	*start_address += sizeof(int);
+
+	/* Write each parameter in turn. */
+	for (i = 0; i < nparams; i++)
+	{
+		SessionVariableValue *svarval;
+		Oid			typeOid;
+		int16		typLen;
+		bool		typByVal;
+
+		svarval = &estate->es_session_variables[i];
+		typeOid = svarval->typid;
+
+		/* Write type OID. */
+		memcpy(*start_address, &typeOid, sizeof(Oid));
+		*start_address += sizeof(Oid);
+
+		Assert(OidIsValid(typeOid));
+		get_typlenbyval(typeOid, &typLen, &typByVal);
+
+		datumSerialize(svarval->value, svarval->isnull, typByVal, typLen,
+					   start_address);
+	}
+}
+
+static SessionVariableValue *
+RestoreSessionVariables(char **start_address, int *num_session_variables)
+{
+	SessionVariableValue *session_variables;
+	int			i;
+	int			nparams;
+
+	memcpy(&nparams, *start_address, sizeof(int));
+	*start_address += sizeof(int);
+
+	*num_session_variables = nparams;
+	session_variables = (SessionVariableValue *)
+		palloc(nparams * sizeof(SessionVariableValue));
+
+	for (i = 0; i < nparams; i++)
+	{
+		SessionVariableValue *svarval = &session_variables[i];
+
+		/* Read type OID. */
+		memcpy(&svarval->typid, *start_address, sizeof(Oid));
+		*start_address += sizeof(Oid);
+
+		/* Read datum/isnull. */
+		svarval->value = datumRestore(start_address, &svarval->isnull);
+	}
+
+	return session_variables;
+}
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index 0063b2b6322..092ed2353ac 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -924,25 +924,19 @@ max_parallel_hazard_walker(Node *node, max_parallel_hazard_context *context)
 
 	/*
 	 * We can't pass Params to workers at the moment either, so they are also
-	 * parallel-restricted, unless they are PARAM_EXTERN Params or are
-	 * PARAM_EXEC Params listed in safe_param_ids, meaning they could be
-	 * either generated within workers or can be computed by the leader and
-	 * then their value can be passed to workers.
+	 * parallel-restricted, unless they are PARAM_EXTERN or PARAM_VARIABLE
+	 * Params or are PARAM_EXEC Params listed in safe_param_ids, meaning they
+	 * could be either generated within workers or can be computed by the
+	 * leader and then their value can be passed to workers.
 	 */
 	else if (IsA(node, Param))
 	{
 		Param	   *param = (Param *) node;
 
-		if (param->paramkind == PARAM_EXTERN)
+		if (param->paramkind == PARAM_EXTERN ||
+			param->paramkind == PARAM_VARIABLE)
 			return false;
 
-		/* we don't support passing session variables to workers */
-		if (param->paramkind == PARAM_VARIABLE)
-		{
-			if (max_parallel_hazard_test(PROPARALLEL_RESTRICTED, context))
-				return true;
-		}
-
 		if (param->paramkind != PARAM_EXEC ||
 			!list_member_int(context->safe_param_ids, param->paramid))
 		{
diff --git a/src/test/regress/expected/session_variables.out b/src/test/regress/expected/session_variables.out
index d8afa7ae4e0..6466a247328 100644
--- a/src/test/regress/expected/session_variables.out
+++ b/src/test/regress/expected/session_variables.out
@@ -642,12 +642,14 @@ SELECT count(*) FROM svar_test WHERE a%10 = zero;
 
 -- parallel execution is not supported yet
 EXPLAIN (COSTS OFF) SELECT count(*) FROM svar_test WHERE a%10 = zero;
-            QUERY PLAN             
------------------------------------
+                 QUERY PLAN                 
+--------------------------------------------
  Aggregate
-   ->  Seq Scan on svar_test
-         Filter: ((a % 10) = zero)
-(3 rows)
+   ->  Gather
+         Workers Planned: 2
+         ->  Parallel Seq Scan on svar_test
+               Filter: ((a % 10) = zero)
+(5 rows)
 
 LET zero = (SELECT count(*) FROM svar_test);
 -- result should be 1000
-- 
2.45.2

