diff options
Diffstat (limited to 'src/backend/executor/nodeAgg.c')
-rw-r--r-- | src/backend/executor/nodeAgg.c | 273 |
1 files changed, 261 insertions, 12 deletions
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 03aa20f61e0..aba54195a30 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -44,6 +44,12 @@ * incorrect. Instead a new state should be created in the correct aggregate * memory context and the 2nd state should be copied over. * + * The 'serialStates' option can be used to allow multi-stage aggregation + * for aggregates with an INTERNAL state type. When this mode is disabled + * only pointers to the INTERNAL aggregate states are passed around the + * executor. When enabled, INTERNAL states are serialized and deserialized + * as required; this is useful when data must be passed between processes. + * * If a normal aggregate call specifies DISTINCT or ORDER BY, we sort the * input tuples and eliminate duplicates (if required) before performing * the above-depicted process. (However, we don't do that for ordered-set @@ -232,6 +238,12 @@ typedef struct AggStatePerTransData /* Oid of the state transition or combine function */ Oid transfn_oid; + /* Oid of the serialization function or InvalidOid */ + Oid serialfn_oid; + + /* Oid of the deserialization function or InvalidOid */ + Oid deserialfn_oid; + /* Oid of state value's datatype */ Oid aggtranstype; @@ -246,6 +258,12 @@ typedef struct AggStatePerTransData */ FmgrInfo transfn; + /* fmgr lookup data for serialization function */ + FmgrInfo serialfn; + + /* fmgr lookup data for deserialization function */ + FmgrInfo deserialfn; + /* Input collation derived for aggregate */ Oid aggCollation; @@ -326,6 +344,11 @@ typedef struct AggStatePerTransData * worth the extra space consumption. 
*/ FunctionCallInfoData transfn_fcinfo; + + /* Likewise for serialization and deserialization functions */ + FunctionCallInfoData serialfn_fcinfo; + + FunctionCallInfoData deserialfn_fcinfo; } AggStatePerTransData; /* @@ -467,6 +490,10 @@ static void finalize_aggregate(AggState *aggstate, AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull); +static void finalize_partialaggregate(AggState *aggstate, + AggStatePerAgg peragg, + AggStatePerGroup pergroupstate, + Datum *resultVal, bool *resultIsNull); static void prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet); @@ -487,12 +514,15 @@ static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggsate, EState *estate, Aggref *aggref, Oid aggtransfn, Oid aggtranstype, - Datum initValue, bool initValueIsNull, - Oid *inputTypes, int numArguments); + Oid aggserialtype, Oid aggserialfn, + Oid aggdeserialfn, Datum initValue, + bool initValueIsNull, Oid *inputTypes, + int numArguments); static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, int lastaggno, List **same_input_transnos); static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, List *transnos); @@ -944,8 +974,45 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) slot = ExecProject(pertrans->evalproj, NULL); Assert(slot->tts_nvalid >= 1); - fcinfo->arg[1] = slot->tts_values[0]; - fcinfo->argnull[1] = slot->tts_isnull[0]; + /* + * deserialfn_oid will be set if we must deserialize the input state + * before calling the combine function + */ + if (OidIsValid(pertrans->deserialfn_oid)) + { + /* + * Don't call a strict deserialization function with NULL input. 
+ * A strict deserialization function and a null value means we skip + * calling the combine function for this state. We assume that this + * would be a waste of time and effort anyway so just skip it. + */ + if (pertrans->deserialfn.fn_strict && slot->tts_isnull[0]) + continue; + else + { + FunctionCallInfo dsinfo = &pertrans->deserialfn_fcinfo; + MemoryContext oldContext; + + dsinfo->arg[0] = slot->tts_values[0]; + dsinfo->argnull[0] = slot->tts_isnull[0]; + + /* + * We run the deserialization functions in per-input-tuple + * memory context. + */ + oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); + + fcinfo->arg[1] = FunctionCallInvoke(dsinfo); + fcinfo->argnull[1] = dsinfo->isnull; + + MemoryContextSwitchTo(oldContext); + } + } + else + { + fcinfo->arg[1] = slot->tts_values[0]; + fcinfo->argnull[1] = slot->tts_isnull[0]; + } advance_combine_function(aggstate, pertrans, pergroupstate); } @@ -1344,6 +1411,61 @@ finalize_aggregate(AggState *aggstate, MemoryContextSwitchTo(oldContext); } +/* + * Compute the final value of one partial aggregate. + * + * The serialization function will be run, and the result delivered, in the + * output-tuple context; caller's CurrentMemoryContext does not matter. + */ +static void +finalize_partialaggregate(AggState *aggstate, + AggStatePerAgg peragg, + AggStatePerGroup pergroupstate, + Datum *resultVal, bool *resultIsNull) +{ + AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno]; + MemoryContext oldContext; + + oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); + + /* + * serialfn_oid will be set if we must serialize the input state + * before calling the combine function on the state. + */ + if (OidIsValid(pertrans->serialfn_oid)) + { + /* Don't call a strict serialization function with NULL input. 
*/ + if (pertrans->serialfn.fn_strict && pergroupstate->transValueIsNull) + { + *resultVal = (Datum) 0; + *resultIsNull = true; + } + else + { + FunctionCallInfo fcinfo = &pertrans->serialfn_fcinfo; + fcinfo->arg[0] = pergroupstate->transValue; + fcinfo->argnull[0] = pergroupstate->transValueIsNull; + + *resultVal = FunctionCallInvoke(fcinfo); + *resultIsNull = fcinfo->isnull; + } + } + else + { + *resultVal = pergroupstate->transValue; + *resultIsNull = pergroupstate->transValueIsNull; + } + + /* If result is pass-by-ref, make sure it is in the right context. */ + if (!peragg->resulttypeByVal && !*resultIsNull && + !MemoryContextContains(CurrentMemoryContext, + DatumGetPointer(*resultVal))) + *resultVal = datumCopy(*resultVal, + peragg->resulttypeByVal, + peragg->resulttypeLen); + + MemoryContextSwitchTo(oldContext); +} /* * Prepare to finalize and project based on the specified representative tuple @@ -1455,10 +1577,8 @@ finalize_aggregates(AggState *aggstate, finalize_aggregate(aggstate, peragg, pergroupstate, &aggvalues[aggno], &aggnulls[aggno]); else - { - aggvalues[aggno] = pergroupstate->transValue; - aggnulls[aggno] = pergroupstate->transValueIsNull; - } + finalize_partialaggregate(aggstate, peragg, pergroupstate, + &aggvalues[aggno], &aggnulls[aggno]); } } @@ -2238,6 +2358,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->agg_done = false; aggstate->combineStates = node->combineStates; aggstate->finalizeAggs = node->finalizeAggs; + aggstate->serialStates = node->serialStates; aggstate->input_done = false; aggstate->pergroup = NULL; aggstate->grp_firstTuple = NULL; @@ -2546,6 +2667,9 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) AclResult aclresult; Oid transfn_oid, finalfn_oid; + Oid serialtype_oid, + serialfn_oid, + deserialfn_oid; Expr *finalfnexpr; Oid aggtranstype; Datum textInitVal; @@ -2610,6 +2734,47 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) else peragg->finalfn_oid = finalfn_oid = InvalidOid; + serialtype_oid = 
InvalidOid; + serialfn_oid = InvalidOid; + deserialfn_oid = InvalidOid; + + /* + * Determine if we require serialization or deserialization of the + * aggregate states. This is only required if the aggregate state is + * internal. + */ + if (aggstate->serialStates && aggform->aggtranstype == INTERNALOID) + { + /* + * The planner should only have generated an agg node with + * serialStates if every aggregate with an INTERNAL state has a + * serialization type, serialization function and deserialization + * function. Let's ensure it didn't mess that up. + */ + if (!OidIsValid(aggform->aggserialtype)) + elog(ERROR, "serialtype not set during serialStates aggregation step"); + + if (!OidIsValid(aggform->aggserialfn)) + elog(ERROR, "serialfunc not set during serialStates aggregation step"); + + if (!OidIsValid(aggform->aggdeserialfn)) + elog(ERROR, "deserialfunc not set during serialStates aggregation step"); + + /* serialization func only required when not finalizing aggs */ + if (!aggstate->finalizeAggs) + { + serialfn_oid = aggform->aggserialfn; + serialtype_oid = aggform->aggserialtype; + } + + /* deserialization func only required when combining states */ + if (aggstate->combineStates) + { + deserialfn_oid = aggform->aggdeserialfn; + serialtype_oid = aggform->aggserialtype; + } + } + /* Check that aggregate owner has permission to call component fns */ { HeapTuple procTuple; @@ -2638,6 +2803,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(finalfn_oid)); InvokeFunctionExecuteHook(finalfn_oid); } + if (OidIsValid(serialfn_oid)) + { + aclresult = pg_proc_aclcheck(serialfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + get_func_name(serialfn_oid)); + InvokeFunctionExecuteHook(serialfn_oid); + } + if (OidIsValid(deserialfn_oid)) + { + aclresult = pg_proc_aclcheck(deserialfn_oid, aggOwner, + ACL_EXECUTE); + if (aclresult != ACLCHECK_OK) + aclcheck_error(aclresult, ACL_KIND_PROC, + 
get_func_name(deserialfn_oid)); + InvokeFunctionExecuteHook(deserialfn_oid); + } } /* @@ -2707,7 +2890,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ existing_transno = find_compatible_pertrans(aggstate, aggref, transfn_oid, aggtranstype, - initValue, initValueIsNull, + serialfn_oid, deserialfn_oid, + initValue, initValueIsNull, same_input_transnos); if (existing_transno != -1) { @@ -2723,8 +2907,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) pertrans = &pertransstates[++transno]; build_pertrans_for_aggref(pertrans, aggstate, estate, aggref, transfn_oid, aggtranstype, - initValue, initValueIsNull, - inputTypes, numArguments); + serialtype_oid, serialfn_oid, + deserialfn_oid, initValue, + initValueIsNull, inputTypes, + numArguments); peragg->transno = transno; } ReleaseSysCache(aggTuple); @@ -2752,11 +2938,14 @@ static void build_pertrans_for_aggref(AggStatePerTrans pertrans, AggState *aggstate, EState *estate, Aggref *aggref, - Oid aggtransfn, Oid aggtranstype, + Oid aggtransfn, Oid aggtranstype, Oid aggserialtype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, Oid *inputTypes, int numArguments) { int numGroupingSets = Max(aggstate->maxsets, 1); + Expr *serialfnexpr = NULL; + Expr *deserialfnexpr = NULL; ListCell *lc; int numInputs; int numDirectArgs; @@ -2770,6 +2959,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, pertrans->aggref = aggref; pertrans->aggCollation = aggref->inputcollid; pertrans->transfn_oid = aggtransfn; + pertrans->serialfn_oid = aggserialfn; + pertrans->deserialfn_oid = aggdeserialfn; pertrans->initValue = initValue; pertrans->initValueIsNull = initValueIsNull; @@ -2809,6 +3000,17 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, 2, pertrans->aggCollation, (void *) aggstate, NULL); + + /* + * Ensure that a combine function to combine INTERNAL states is not + * strict. 
This should have been checked during CREATE AGGREGATE, but + * the strict property could have been changed since then. + */ + if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("combine function for aggregate %u must not be declared as strict", + aggref->aggfnoid))); } else { @@ -2861,6 +3063,41 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, &pertrans->transtypeLen, &pertrans->transtypeByVal); + if (OidIsValid(aggserialfn)) + { + build_aggregate_serialfn_expr(aggtranstype, + aggserialtype, + aggref->inputcollid, + aggserialfn, + &serialfnexpr); + fmgr_info(aggserialfn, &pertrans->serialfn); + fmgr_info_set_expr((Node *) serialfnexpr, &pertrans->serialfn); + + InitFunctionCallInfoData(pertrans->serialfn_fcinfo, + &pertrans->serialfn, + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + } + + if (OidIsValid(aggdeserialfn)) + { + build_aggregate_serialfn_expr(aggserialtype, + aggtranstype, + aggref->inputcollid, + aggdeserialfn, + &deserialfnexpr); + fmgr_info(aggdeserialfn, &pertrans->deserialfn); + fmgr_info_set_expr((Node *) deserialfnexpr, &pertrans->deserialfn); + + InitFunctionCallInfoData(pertrans->deserialfn_fcinfo, + &pertrans->deserialfn, + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + } + /* * Get a tupledesc corresponding to the aggregated inputs (including sort * expressions) of the agg. 
@@ -3107,6 +3344,7 @@ find_compatible_peragg(Aggref *newagg, AggState *aggstate, static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, Oid aggtransfn, Oid aggtranstype, + Oid aggserialfn, Oid aggdeserialfn, Datum initValue, bool initValueIsNull, List *transnos) { @@ -3125,6 +3363,17 @@ find_compatible_pertrans(AggState *aggstate, Aggref *newagg, aggtranstype != pertrans->aggtranstype) continue; + /* + * The serialization and deserialization functions must match, if + * present, as we're unable to share the trans state for aggregates + * which will serialize or deserialize into different formats. Remember + * that these will be InvalidOid if they're not required for this agg + * node. + */ + if (aggserialfn != pertrans->serialfn_oid || + aggdeserialfn != pertrans->deserialfn_oid) + continue; + /* Check that the initial condition matches, too. */ if (initValueIsNull && pertrans->initValueIsNull) return transno; |