aboutsummaryrefslogtreecommitdiff
path: root/src/backend/executor/nodeAgg.c
diff options
context:
space:
mode:
authorDavid Rowley <drowley@postgresql.org>2021-07-04 18:47:31 +1200
committerDavid Rowley <drowley@postgresql.org>2021-07-04 18:47:31 +1200
commit63b1af94375cc2be06a5d6a932db24cd8e9f45e9 (patch)
treeeb3a09111a707b132221674895b9190ee06bab4c /src/backend/executor/nodeAgg.c
parent792259591c0fc19c42247fc7668b1064d1e850d4 (diff)
downloadpostgresql-63b1af94375cc2be06a5d6a932db24cd8e9f45e9.tar.gz
postgresql-63b1af94375cc2be06a5d6a932db24cd8e9f45e9.zip
Cleanup some aggregate code in the executor
Here we alter the code that calls build_pertrans_for_aggref() so that the function no longer needs to special-case whether it's dealing with an aggtransfn or an aggcombinefn. This allows us to reuse the build_aggregate_transfn_expr() function and just get rid of the build_aggregate_combinefn_expr() completely. All of the special case code that was in build_pertrans_for_aggref() has been moved up to the calling functions. This saves about a dozen lines of code in nodeAgg.c and a few dozen more in parse_agg.c Also, rename a few variables in nodeAgg.c to try to make it more clear that we're working with either a aggtransfn or an aggcombinefn. Some of the old names would have you believe that we were always working with an aggtransfn. Discussion: https://postgr.es/m/CAApHDvptMQ9FmF0D67zC_w88yVnoNVR2+kkOQGUrCmdxWxLULQ@mail.gmail.com
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r--src/backend/executor/nodeAgg.c234
1 files changed, 113 insertions, 121 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 8440a76fbdc..914b02ceee4 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -461,10 +461,11 @@ static void hashagg_tapeinfo_release(HashTapeInfo *tapeinfo, int tapenum);
static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggstate, EState *estate,
- Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
- Oid aggserialfn, Oid aggdeserialfn,
- Datum initValue, bool initValueIsNull,
- Oid *inputTypes, int numArguments);
+ Aggref *aggref, Oid transfn_oid,
+ Oid aggtranstype, Oid aggserialfn,
+ Oid aggdeserialfn, Datum initValue,
+ bool initValueIsNull, Oid *inputTypes,
+ int numArguments);
/*
@@ -3724,8 +3725,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
Aggref *aggref = lfirst(l);
AggStatePerAgg peragg;
AggStatePerTrans pertrans;
- Oid inputTypes[FUNC_MAX_ARGS];
- int numArguments;
+ Oid aggTransFnInputTypes[FUNC_MAX_ARGS];
+ int numAggTransFnArgs;
int numDirectArgs;
HeapTuple aggTuple;
Form_pg_aggregate aggform;
@@ -3859,14 +3860,15 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* could be different from the agg's declared input types, when the
* agg accepts ANY or a polymorphic type.
*/
- numArguments = get_aggregate_argtypes(aggref, inputTypes);
+ numAggTransFnArgs = get_aggregate_argtypes(aggref,
+ aggTransFnInputTypes);
/* Count the "direct" arguments, if any */
numDirectArgs = list_length(aggref->aggdirectargs);
/* Detect how many arguments to pass to the finalfn */
if (aggform->aggfinalextra)
- peragg->numFinalArgs = numArguments + 1;
+ peragg->numFinalArgs = numAggTransFnArgs + 1;
else
peragg->numFinalArgs = numDirectArgs + 1;
@@ -3880,7 +3882,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
*/
if (OidIsValid(finalfn_oid))
{
- build_aggregate_finalfn_expr(inputTypes,
+ build_aggregate_finalfn_expr(aggTransFnInputTypes,
peragg->numFinalArgs,
aggtranstype,
aggref->aggtype,
@@ -3911,7 +3913,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
/*
* If this aggregation is performing state combines, then instead
* of using the transition function, we'll use the combine
- * function
+ * function.
*/
if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
{
@@ -3924,8 +3926,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
else
transfn_oid = aggform->aggtransfn;
- aclresult = pg_proc_aclcheck(transfn_oid, aggOwner,
- ACL_EXECUTE);
+ aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, ACL_EXECUTE);
if (aclresult != ACLCHECK_OK)
aclcheck_error(aclresult, OBJECT_FUNCTION,
get_func_name(transfn_oid));
@@ -3943,11 +3944,72 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
else
initValue = GetAggInitVal(textInitVal, aggtranstype);
- build_pertrans_for_aggref(pertrans, aggstate, estate,
- aggref, transfn_oid, aggtranstype,
- serialfn_oid, deserialfn_oid,
- initValue, initValueIsNull,
- inputTypes, numArguments);
+ if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
+ {
+ Oid combineFnInputTypes[] = {aggtranstype,
+ aggtranstype};
+
+ /*
+ * When combining there's only one input, the to-be-combined
+ * transition value. The transition value is not counted
+ * here.
+ */
+ pertrans->numTransInputs = 1;
+
+ /* aggcombinefn always has two arguments of aggtranstype */
+ build_pertrans_for_aggref(pertrans, aggstate, estate,
+ aggref, transfn_oid, aggtranstype,
+ serialfn_oid, deserialfn_oid,
+ initValue, initValueIsNull,
+ combineFnInputTypes, 2);
+
+ /*
+ * Ensure that a combine function to combine INTERNAL states
+ * is not strict. This should have been checked during CREATE
+ * AGGREGATE, but the strict property could have been changed
+ * since then.
+ */
+ if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("combine function with transition type %s must not be declared STRICT",
+ format_type_be(aggtranstype))));
+ }
+ else
+ {
+ /* Detect how many arguments to pass to the transfn */
+ if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+ pertrans->numTransInputs = list_length(aggref->args);
+ else
+ pertrans->numTransInputs = numAggTransFnArgs;
+
+ build_pertrans_for_aggref(pertrans, aggstate, estate,
+ aggref, transfn_oid, aggtranstype,
+ serialfn_oid, deserialfn_oid,
+ initValue, initValueIsNull,
+ aggTransFnInputTypes,
+ numAggTransFnArgs);
+
+ /*
+ * If the transfn is strict and the initval is NULL, make sure
+ * input type and transtype are the same (or at least
+ * binary-compatible), so that it's OK to use the first
+ * aggregated input value as the initial transValue. This
+ * should have been checked at agg definition time, but we
+ * must check again in case the transfn's strictness property
+ * has been changed.
+ */
+ if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+ {
+ if (numAggTransFnArgs <= numDirectArgs ||
+ !IsBinaryCoercible(aggTransFnInputTypes[numDirectArgs],
+ aggtranstype))
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+ errmsg("aggregate %u needs to have compatible input type and transition type",
+ aggref->aggfnoid)));
+ }
+ }
}
else
pertrans->aggshared = true;
@@ -4039,20 +4101,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* Build the state needed to calculate a state value for an aggregate.
*
* This initializes all the fields in 'pertrans'. 'aggref' is the aggregate
- * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest
+ * to initialize the state for. 'transfn_oid', 'aggtranstype', and the rest
* of the arguments could be calculated from 'aggref', but the caller has
* calculated them already, so might as well pass them.
+ *
+ * 'transfn_oid' may be either the Oid of the aggtransfn or the aggcombinefn.
*/
static void
build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggstate, EState *estate,
Aggref *aggref,
- Oid aggtransfn, Oid aggtranstype,
+ Oid transfn_oid, Oid aggtranstype,
Oid aggserialfn, Oid aggdeserialfn,
Datum initValue, bool initValueIsNull,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
+ Expr *transfnexpr;
+ int numTransArgs;
Expr *serialfnexpr = NULL;
Expr *deserialfnexpr = NULL;
ListCell *lc;
@@ -4067,7 +4133,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->aggref = aggref;
pertrans->aggshared = false;
pertrans->aggCollation = aggref->inputcollid;
- pertrans->transfn_oid = aggtransfn;
+ pertrans->transfn_oid = transfn_oid;
pertrans->serialfn_oid = aggserialfn;
pertrans->deserialfn_oid = aggdeserialfn;
pertrans->initValue = initValue;
@@ -4081,111 +4147,34 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->aggtranstype = aggtranstype;
+ /* account for the current transition state */
+ numTransArgs = pertrans->numTransInputs + 1;
+
/*
- * When combining states, we have no use at all for the aggregate
- * function's transfn. Instead we use the combinefn. In this case, the
- * transfn and transfn_oid fields of pertrans refer to the combine
- * function rather than the transition function.
+ * Set up infrastructure for calling the transfn. Note that invtrans is
+ * not needed here.
*/
- if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
- {
- Expr *combinefnexpr;
- size_t numTransArgs;
-
- /*
- * When combining there's only one input, the to-be-combined added
- * transition value from below (this node's transition value is
- * counted separately).
- */
- pertrans->numTransInputs = 1;
-
- /* account for the current transition state */
- numTransArgs = pertrans->numTransInputs + 1;
-
- build_aggregate_combinefn_expr(aggtranstype,
- aggref->inputcollid,
- aggtransfn,
- &combinefnexpr);
- fmgr_info(aggtransfn, &pertrans->transfn);
- fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn);
-
- pertrans->transfn_fcinfo =
- (FunctionCallInfo) palloc(SizeForFunctionCallInfo(2));
- InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
- &pertrans->transfn,
- numTransArgs,
- pertrans->aggCollation,
- (void *) aggstate, NULL);
-
- /*
- * Ensure that a combine function to combine INTERNAL states is not
- * strict. This should have been checked during CREATE AGGREGATE, but
- * the strict property could have been changed since then.
- */
- if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
- ereport(ERROR,
- (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
- errmsg("combine function with transition type %s must not be declared STRICT",
- format_type_be(aggtranstype))));
- }
- else
- {
- Expr *transfnexpr;
- size_t numTransArgs;
-
- /* Detect how many arguments to pass to the transfn */
- if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
- pertrans->numTransInputs = numInputs;
- else
- pertrans->numTransInputs = numArguments;
+ build_aggregate_transfn_expr(inputTypes,
+ numArguments,
+ numDirectArgs,
+ aggref->aggvariadic,
+ aggtranstype,
+ aggref->inputcollid,
+ transfn_oid,
+ InvalidOid,
+ &transfnexpr,
+ NULL);
- /* account for the current transition state */
- numTransArgs = pertrans->numTransInputs + 1;
+ fmgr_info(transfn_oid, &pertrans->transfn);
+ fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
- /*
- * Set up infrastructure for calling the transfn. Note that
- * invtransfn is not needed here.
- */
- build_aggregate_transfn_expr(inputTypes,
- numArguments,
- numDirectArgs,
- aggref->aggvariadic,
- aggtranstype,
- aggref->inputcollid,
- aggtransfn,
- InvalidOid,
- &transfnexpr,
- NULL);
- fmgr_info(aggtransfn, &pertrans->transfn);
- fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
-
- pertrans->transfn_fcinfo =
- (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs));
- InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
- &pertrans->transfn,
- numTransArgs,
- pertrans->aggCollation,
- (void *) aggstate, NULL);
-
- /*
- * If the transfn is strict and the initval is NULL, make sure input
- * type and transtype are the same (or at least binary-compatible), so
- * that it's OK to use the first aggregated input value as the initial
- * transValue. This should have been checked at agg definition time,
- * but we must check again in case the transfn's strictness property
- * has been changed.
- */
- if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
- {
- if (numArguments <= numDirectArgs ||
- !IsBinaryCoercible(inputTypes[numDirectArgs],
- aggtranstype))
- ereport(ERROR,
- (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
- errmsg("aggregate %u needs to have compatible input type and transition type",
- aggref->aggfnoid)));
- }
- }
+ pertrans->transfn_fcinfo =
+ (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs));
+ InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
+ &pertrans->transfn,
+ numTransArgs,
+ pertrans->aggCollation,
+ (void *) aggstate, NULL);
/* get info about the state value's datatype */
get_typlenbyval(aggtranstype,
@@ -4276,6 +4265,9 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
*/
Assert(aggstate->aggstrategy != AGG_HASHED && aggstate->aggstrategy != AGG_MIXED);
+ /* ORDER BY aggregates are not supported with partial aggregation */
+ Assert(!DO_AGGSPLIT_COMBINE(aggstate->aggsplit));
+
/* If we have only one input, we need its len/byval info. */
if (numInputs == 1)
{