aboutsummaryrefslogtreecommitdiff
path: root/src/backend/executor/nodeAgg.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r--src/backend/executor/nodeAgg.c273
1 files changed, 261 insertions, 12 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 03aa20f61e0..aba54195a30 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -44,6 +44,12 @@
* incorrect. Instead a new state should be created in the correct aggregate
* memory context and the 2nd state should be copied over.
*
+ * The 'serialStates' option can be used to allow multi-stage aggregation
+ * for aggregates with an INTERNAL state type. When this mode is disabled,
+ * only pointers to the INTERNAL aggregate states are passed around the
+ * executor. When enabled, INTERNAL states are serialized and deserialized
+ * as required; this is useful when data must be passed between processes.
+ *
* If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the
* input tuples and eliminate duplicates (if required) before performing
* the above-depicted process. (However, we don't do that for ordered-set
@@ -232,6 +238,12 @@ typedef struct AggStatePerTransData
/* Oid of the state transition or combine function */
Oid transfn_oid;
+ /* Oid of the serialization function or InvalidOid */
+ Oid serialfn_oid;
+
+ /* Oid of the deserialization function or InvalidOid */
+ Oid deserialfn_oid;
+
/* Oid of state value's datatype */
Oid aggtranstype;
@@ -246,6 +258,12 @@ typedef struct AggStatePerTransData
*/
FmgrInfo transfn;
+ /* fmgr lookup data for serialization function */
+ FmgrInfo serialfn;
+
+ /* fmgr lookup data for deserialization function */
+ FmgrInfo deserialfn;
+
/* Input collation derived for aggregate */
Oid aggCollation;
@@ -326,6 +344,11 @@ typedef struct AggStatePerTransData
* worth the extra space consumption.
*/
FunctionCallInfoData transfn_fcinfo;
+
+ /* Likewise for serialization and deserialization functions */
+ FunctionCallInfoData serialfn_fcinfo;
+
+ FunctionCallInfoData deserialfn_fcinfo;
} AggStatePerTransData;
/*
@@ -467,6 +490,10 @@ static void finalize_aggregate(AggState *aggstate,
AggStatePerAgg peragg,
AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull);
+static void finalize_partialaggregate(AggState *aggstate,
+ AggStatePerAgg peragg,
+ AggStatePerGroup pergroupstate,
+ Datum *resultVal, bool *resultIsNull);
static void prepare_projection_slot(AggState *aggstate,
TupleTableSlot *slot,
int currentSet);
@@ -487,12 +514,15 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggsate, EState *estate,
Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
- Datum initValue, bool initValueIsNull,
- Oid *inputTypes, int numArguments);
+ Oid aggserialtype, Oid aggserialfn,
+ Oid aggdeserialfn, Datum initValue,
+ bool initValueIsNull, Oid *inputTypes,
+ int numArguments);
static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
int lastaggno, List **same_input_transnos);
static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
List *transnos);
@@ -944,8 +974,45 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
slot = ExecProject(pertrans->evalproj, NULL);
Assert(slot->tts_nvalid >= 1);
- fcinfo->arg[1] = slot->tts_values[0];
- fcinfo->argnull[1] = slot->tts_isnull[0];
+ /*
+ * deserialfn_oid will be set if we must deserialize the input state
+ * before calling the combine function
+ */
+ if (OidIsValid(pertrans->deserialfn_oid))
+ {
+ /*
+ * Don't call a strict deserialization function with NULL input.
+ * A strict deserialization function and a null value means we skip
+ * calling the combine function for this state. We assume that this
+ * would be a waste of time and effort anyway so just skip it.
+ */
+ if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0])
+ continue;
+ else
+ {
+ FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo;
+ MemoryContext oldContext;
+
+ dsinfo->arg[0] = slot->tts_values[0];
+ dsinfo->argnull[0] = slot->tts_isnull[0];
+
+ /*
+ * We run the deserialization functions in per-input-tuple
+ * memory context.
+ */
+ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
+
+ fcinfo->arg[1] = FunctionCallInvoke(dsinfo);
+ fcinfo->argnull[1] = dsinfo->isnull;
+
+ MemoryContextSwitchTo(oldContext);
+ }
+ }
+ else
+ {
+ fcinfo->arg[1] = slot->tts_values[0];
+ fcinfo->argnull[1] = slot->tts_isnull[0];
+ }
advance_combine_function(aggstate, pertrans, pergroupstate);
}
@@ -1344,6 +1411,61 @@ finalize_aggregate(AggState *aggstate,
MemoryContextSwitchTo(oldContext);
}
+/*
+ * Emit one partial aggregate's transition state, serialized if required.
+ *
+ * The serialization function is run, and the result delivered, in the
+ * output-tuple context; caller's CurrentMemoryContext does not matter.
+ */
+static void
+finalize_partialaggregate(AggState *aggstate,
+ AggStatePerAgg peragg,
+ AggStatePerGroup pergroupstate,
+ Datum *resultVal, bool *resultIsNull)
+{
+ AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno];
+ MemoryContext oldContext;
+
+ oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
+
+ /*
+ * serialfn_oid will be set if we must serialize the transition state
+ * before handing it on; otherwise the state is returned as-is.
+ */
+ if (OidIsValid(pertrans->serialfn_oid))
+ {
+ /* A strict serialization function is skipped on NULL; emit NULL. */
+ if (pertrans->serialfn.fn_strict && pergroupstate->transValueIsNull)
+ {
+ *resultVal = (Datum) 0;
+ *resultIsNull = true;
+ }
+ else
+ {
+ FunctionCallInfo fcinfo = &pertrans->serialfn_fcinfo;
+ fcinfo->arg[0] = pergroupstate->transValue;
+ fcinfo->argnull[0] = pergroupstate->transValueIsNull;
+
+ *resultVal = FunctionCallInvoke(fcinfo);
+ *resultIsNull = fcinfo->isnull;
+ }
+ }
+ else
+ {
+ *resultVal = pergroupstate->transValue;
+ *resultIsNull = pergroupstate->transValueIsNull;
+ }
+
+ /* Copy a pass-by-ref result unless it already lives in this context. */
+ if (!peragg->resulttypeByVal && !*resultIsNull &&
+ !MemoryContextContains(CurrentMemoryContext,
+ DatumGetPointer(*resultVal)))
+ *resultVal = datumCopy(*resultVal,
+ peragg->resulttypeByVal,
+ peragg->resulttypeLen);
+
+ MemoryContextSwitchTo(oldContext);
+}
/*
* Prepare to finalize and project based on the specified representative tuple
@@ -1455,10 +1577,8 @@ finalize_aggregates(AggState *aggstate,
finalize_aggregate(aggstate, peragg, pergroupstate,
&aggvalues[aggno], &aggnulls[aggno]);
else
- {
- aggvalues[aggno] = pergroupstate->transValue;
- aggnulls[aggno] = pergroupstate->transValueIsNull;
- }
+ finalize_partialaggregate(aggstate, peragg, pergroupstate,
+ &aggvalues[aggno], &aggnulls[aggno]);
}
}
@@ -2238,6 +2358,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->agg_done = false;
aggstate->combineStates = node->combineStates;
aggstate->finalizeAggs = node->finalizeAggs;
+ aggstate->serialStates = node->serialStates;
aggstate->input_done = false;
aggstate->pergroup = NULL;
aggstate->grp_firstTuple = NULL;
@@ -2546,6 +2667,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
AclResult aclresult;
Oid transfn_oid,
finalfn_oid;
+ Oid serialtype_oid,
+ serialfn_oid,
+ deserialfn_oid;
Expr *finalfnexpr;
Oid aggtranstype;
Datum textInitVal;
@@ -2610,6 +2734,47 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
else
peragg->finalfn_oid = finalfn_oid = InvalidOid;
+ serialtype_oid = InvalidOid;
+ serialfn_oid = InvalidOid;
+ deserialfn_oid = InvalidOid;
+
+ /*
+ * Determine if we require serialization or deserialization of the
+ * aggregate states. This is only required if the aggregate state is
+ * internal.
+ */
+ if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID)
+ {
+ /*
+ * The planner should only have generated an agg node with
+ * serialStates if every aggregate with an INTERNAL state has a
+ * serialization type, serialization function and deserialization
+ * function. Let's ensure it didn't mess that up.
+ */
+ if (!OidIsValid(aggform->aggserialtype))
+ elog(ERROR, "serialtype not set during serialStates aggregation step");
+
+ if (!OidIsValid(aggform->aggserialfn))
+ elog(ERROR, "serialfunc not set during serialStates aggregation step");
+
+ if (!OidIsValid(aggform->aggdeserialfn))
+ elog(ERROR, "deserialfunc not set during serialStates aggregation step");
+
+ /* serialization func only required when not finalizing aggs */
+ if (!aggstate->finalizeAggs)
+ {
+ serialfn_oid = aggform->aggserialfn;
+ serialtype_oid = aggform->aggserialtype;
+ }
+
+ /* deserialization func only required when combining states */
+ if (aggstate->combineStates)
+ {
+ deserialfn_oid = aggform->aggdeserialfn;
+ serialtype_oid = aggform->aggserialtype;
+ }
+ }
+
/* Check that aggregate owner has permission to call component fns */
{
HeapTuple procTuple;
@@ -2638,6 +2803,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(finalfn_oid));
InvokeFunctionExecuteHook(finalfn_oid);
}
+ if (OidIsValid(serialfn_oid))
+ {
+ aclresult = pg_proc_aclcheck(serialfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(serialfn_oid));
+ InvokeFunctionExecuteHook(serialfn_oid);
+ }
+ if (OidIsValid(deserialfn_oid))
+ {
+ aclresult = pg_proc_aclcheck(deserialfn_oid, aggOwner,
+ ACL_EXECUTE);
+ if (aclresult != ACLCHECK_OK)
+ aclcheck_error(aclresult, ACL_KIND_PROC,
+ get_func_name(deserialfn_oid));
+ InvokeFunctionExecuteHook(deserialfn_oid);
+ }
}
/*
@@ -2707,7 +2890,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
*/
existing_transno = find_compatible_pertrans(aggstate, aggref,
transfn_oid, aggtranstype,
- initValue, initValueIsNull,
+ serialfn_oid, deserialfn_oid,
+ initValue, initValueIsNull,
same_input_transnos);
if (existing_transno != -1)
{
@@ -2723,8 +2907,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
pertrans = &pertransstates[++transno];
build_pertrans_for_aggref(pertrans, aggstate, estate,
aggref, transfn_oid, aggtranstype,
- initValue, initValueIsNull,
- inputTypes, numArguments);
+ serialtype_oid, serialfn_oid,
+ deserialfn_oid, initValue,
+ initValueIsNull, inputTypes,
+ numArguments);
peragg->transno = transno;
}
ReleaseSysCache(aggTuple);
@@ -2752,11 +2938,14 @@ static void
build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggstate, EState *estate,
Aggref *aggref,
- Oid aggtransfn, Oid aggtranstype,
+ Oid aggtransfn, Oid aggtranstype, Oid aggserialtype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
+ Expr *serialfnexpr = NULL;
+ Expr *deserialfnexpr = NULL;
ListCell *lc;
int numInputs;
int numDirectArgs;
@@ -2770,6 +2959,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->aggref = aggref;
pertrans->aggCollation = aggref->inputcollid;
pertrans->transfn_oid = aggtransfn;
+ pertrans->serialfn_oid = aggserialfn;
+ pertrans->deserialfn_oid = aggdeserialfn;
pertrans->initValue = initValue;
pertrans->initValueIsNull = initValueIsNull;
@@ -2809,6 +3000,17 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
2,
pertrans->aggCollation,
(void *) aggstate, NULL);
+
+ /*
+ * Ensure that a combine function to combine INTERNAL states is not
+ * strict. This should have been checked during CREATE AGGREGATE, but
+ * the strict property could have been changed since then.
+ */
+ if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("combine function for aggregate %u must not be declared as strict",
+ aggref->aggfnoid)));
}
else
{
@@ -2861,6 +3063,41 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
&pertrans->transtypeLen,
&pertrans->transtypeByVal);
+ if (OidIsValid(aggserialfn))
+ {
+ build_aggregate_serialfn_expr(aggtranstype,
+ aggserialtype,
+ aggref->inputcollid,
+ aggserialfn,
+ &serialfnexpr);
+ fmgr_info(aggserialfn, &pertrans->serialfn);
+ fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn);
+
+ InitFunctionCallInfoData(pertrans->serialfn_fcinfo,
+ &pertrans->serialfn,
+ 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+ }
+
+ if (OidIsValid(aggdeserialfn))
+ {
+ build_aggregate_serialfn_expr(aggserialtype,
+ aggtranstype,
+ aggref->inputcollid,
+ aggdeserialfn,
+ &deserialfnexpr);
+ fmgr_info(aggdeserialfn, &pertrans->deserialfn);
+ fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn);
+
+ InitFunctionCallInfoData(pertrans->deserialfn_fcinfo,
+ &pertrans->deserialfn,
+ 1,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
+
+ }
+
/*
* Get a tupledesc corresponding to the aggregated inputs (including sort
* expressions) of the agg.
@@ -3107,6 +3344,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate,
static int
find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
+ Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
List *transnos)
{
@@ -3125,6 +3363,17 @@ find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
aggtranstype != pertrans->aggtranstype)
continue;
+ /*
+ * The serialization and deserialization functions must match, if
+ * present, as we're unable to share the trans state for aggregates
+ * which will serialize or deserialize into different formats. Remember
+ * that these will be InvalidOid if they're not required for this agg
+ * node.
+ */
+ if (aggserialfn != pertrans->serialfn_oid ||
+ aggdeserialfn != pertrans->deserialfn_oid)
+ continue;
+
/* Check that the initial condition matches, too. */
if (initValueIsNull && pertrans->initValueIsNull)
return transno;