From 804163bc25e979fcd91b02e58fa2d1c6b587cc65 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Tue, 4 Aug 2015 17:53:10 +0300 Subject: [PATCH] Share transition state between different aggregates when possible. If there are two different aggregates in the query with same inputs, and the aggregates have the same initial condition and transition function, only calculate the state value once, and only call the final functions separately. For example, AVG(x) and SUM(x) aggregates have the same transition function, which accumulates the sum and number of input tuples. For a query like "SELECT AVG(x), SUM(x) FROM x", we can therefore accumulate the state function only once, which gives a nice speedup. David Rowley, reviewed and edited by me. --- src/backend/executor/execQual.c | 22 +- src/backend/executor/nodeAgg.c | 1094 ++++++++++++++-------- src/backend/executor/nodeWindowAgg.c | 31 +- src/backend/parser/parse_agg.c | 75 +- src/include/nodes/execnodes.h | 8 +- src/include/parser/parse_agg.h | 14 +- src/test/regress/expected/aggregates.out | 204 ++++ src/test/regress/sql/aggregates.sql | 165 ++++ 8 files changed, 1144 insertions(+), 469 deletions(-) diff --git a/src/backend/executor/execQual.c b/src/backend/executor/execQual.c index 16bc8fa5f6..29f058ce5c 100644 --- a/src/backend/executor/execQual.c +++ b/src/backend/executor/execQual.c @@ -4487,35 +4487,15 @@ ExecInitExpr(Expr *node, PlanState *parent) break; case T_Aggref: { - Aggref *aggref = (Aggref *) node; AggrefExprState *astate = makeNode(AggrefExprState); astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref; if (parent && IsA(parent, AggState)) { AggState *aggstate = (AggState *) parent; - int naggs; aggstate->aggs = lcons(astate, aggstate->aggs); - naggs = ++aggstate->numaggs; - - astate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs, - parent); - astate->args = (List *) ExecInitExpr((Expr *) aggref->args, - parent); - astate->aggfilter = ExecInitExpr(aggref->aggfilter, - parent); - - /* - * Complain if the aggregate's arguments contain any - * aggregates; nested agg functions are semantically - * nonsensical. (This should have been caught earlier, - * but we defend against it here anyway.) - */ - if (naggs != aggstate->numaggs) - ereport(ERROR, - (errcode(ERRCODE_GROUPING_ERROR), - errmsg("aggregate function calls cannot be nested"))); + aggstate->numaggs++; } else { diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 2bf48c54e3..2e3685557b 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -152,17 +152,28 @@ /* - * AggStatePerAggData - per-aggregate working state for the Agg scan + * AggStatePerTransData - per aggregate state value information + * + * Working state for updating the aggregate's state value, by calling the + * transition function with an input row. This struct does not store the + * information needed to produce the final aggregate result from the transition + * state, that's stored in AggStatePerAggData instead. This separation allows + * multiple aggregate results to be produced from a single state value. */ -typedef struct AggStatePerAggData +typedef struct AggStatePerTransData { /* * These values are set up during ExecInitAgg() and do not change * thereafter: */ - /* Links to Aggref expr and state nodes this working state is for */ - AggrefExprState *aggrefstate; + /* + * Link to an Aggref expr this state value is for. + * + * There can be multiple Aggref's sharing the same state value, as long as + * the inputs and transition function are identical. This points to the + * first one of them. + */ Aggref *aggref; /* @@ -186,25 +197,22 @@ typedef struct AggStatePerAggData */ int numTransInputs; - /* - * Number of arguments to pass to the finalfn. This is always at least 1 - * (the transition state value) plus any ordered-set direct args. If the - * finalfn wants extra args then we pass nulls corresponding to the - * aggregated input columns. - */ - int numFinalArgs; - - /* Oids of transfer functions */ + /* Oid of the state transition function */ Oid transfn_oid; - Oid finalfn_oid; /* may be InvalidOid */ + + /* Oid of state value's datatype */ + Oid aggtranstype; + + /* ExprStates of the FILTER and argument expressions. */ + ExprState *aggfilter; /* state of FILTER expression, if any */ + List *args; /* states of aggregated-argument expressions */ + List *aggdirectargs; /* states of direct-argument expressions */ /* - * fmgr lookup data for transfer functions --- only valid when - * corresponding oid is not InvalidOid. Note in particular that fn_strict - * flags are kept here. + * fmgr lookup data for transition function. Note in particular that the + * fn_strict flag is kept here. */ FmgrInfo transfn; - FmgrInfo finalfn; /* Input collation derived for aggregate */ Oid aggCollation; @@ -236,17 +244,15 @@ typedef struct AggStatePerAggData bool initValueIsNull; /* - * We need the len and byval info for the agg's input, result, and - * transition data types in order to know how to copy/delete values. + * We need the len and byval info for the agg's input and transition data + * types in order to know how to copy/delete values. * * Note that the info for the input type is used only when handling * DISTINCT aggs with just one argument, so there is only one input type. */ int16 inputtypeLen, - resulttypeLen, transtypeLen; bool inputtypeByVal, - resulttypeByVal, transtypeByVal; /* @@ -288,6 +294,54 @@ typedef struct AggStatePerAggData * worth the extra space consumption. */ FunctionCallInfoData transfn_fcinfo; +} AggStatePerTransData; + +/* + * AggStatePerAggData - per-aggregate information + * + * This contains the information needed to call the final function, to produce + * a final aggregate result from the state value. If there are multiple + * identical Aggrefs in the query, they can all share the same per-agg data. + * + * These values are set up during ExecInitAgg() and do not change thereafter. + */ +typedef struct AggStatePerAggData +{ + /* + * Link to an Aggref expr this state value is for. + * + * There can be multiple identical Aggref's sharing the same per-agg. This + * points to the first one of them. + */ + Aggref *aggref; + + /* index to the state value which this agg should use */ + int transno; + + /* Optional Oid of final function (may be InvalidOid) */ + Oid finalfn_oid; + + /* + * fmgr lookup data for final function --- only valid when finalfn_oid oid + * is not InvalidOid. + */ + FmgrInfo finalfn; + + /* + * Number of arguments to pass to the finalfn. This is always at least 1 + * (the transition state value) plus any ordered-set direct args. If the + * finalfn wants extra args then we pass nulls corresponding to the + * aggregated input columns. + */ + int numFinalArgs; + + /* + * We need the len and byval info for the agg's result data type in order + * to know how to copy/delete values. + */ + int16 resulttypeLen; + bool resulttypeByVal; + } AggStatePerAggData; /* @@ -358,25 +412,23 @@ typedef struct AggHashEntryData AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER]; } AggHashEntryData; - static void initialize_phase(AggState *aggstate, int newphase); static TupleTableSlot *fetch_input_tuple(AggState *aggstate); static void initialize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, AggStatePerGroup pergroup, int numReset); static void advance_transition_function(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup); static void process_ordered_aggregate_single(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void process_ordered_aggregate_multi(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void finalize_aggregate(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull); static void prepare_projection_slot(AggState *aggstate, @@ -396,6 +448,17 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate); static void agg_fill_hash_table(AggState *aggstate); static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate); static Datum GetAggInitVal(Datum textInitVal, Oid transtype); +static void build_pertrans_for_aggref(AggStatePerTrans pertrans, + AggState *aggsate, EState *estate, + Aggref *aggref, Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + Oid *inputTypes, int numArguments); +static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, + int lastaggno, List **same_input_transnos); +static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, + Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + List *possible_matches); /* @@ -498,20 +561,20 @@ fetch_input_tuple(AggState *aggstate) * When called, CurrentMemoryContext should be the per-query context. */ static void -initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, +initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { /* * Start a fresh sort operation for each DISTINCT/ORDER BY aggregate. */ - if (peraggstate->numSortCols > 0) + if (pertrans->numSortCols > 0) { /* * In case of rescan, maybe there could be an uncompleted sort * operation? Clean it up if so. */ - if (peraggstate->sortstates[aggstate->current_set]) - tuplesort_end(peraggstate->sortstates[aggstate->current_set]); + if (pertrans->sortstates[aggstate->current_set]) + tuplesort_end(pertrans->sortstates[aggstate->current_set]); /* @@ -519,21 +582,21 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, * otherwise sort the full tuple. (See comments for * process_ordered_aggregate_single.) */ - if (peraggstate->numInputs == 1) - peraggstate->sortstates[aggstate->current_set] = - tuplesort_begin_datum(peraggstate->evaldesc->attrs[0]->atttypid, - peraggstate->sortOperators[0], - peraggstate->sortCollations[0], - peraggstate->sortNullsFirst[0], + if (pertrans->numInputs == 1) + pertrans->sortstates[aggstate->current_set] = + tuplesort_begin_datum(pertrans->evaldesc->attrs[0]->atttypid, + pertrans->sortOperators[0], + pertrans->sortCollations[0], + pertrans->sortNullsFirst[0], work_mem, false); else - peraggstate->sortstates[aggstate->current_set] = - tuplesort_begin_heap(peraggstate->evaldesc, - peraggstate->numSortCols, - peraggstate->sortColIdx, - peraggstate->sortOperators, - peraggstate->sortCollations, - peraggstate->sortNullsFirst, + pertrans->sortstates[aggstate->current_set] = + tuplesort_begin_heap(pertrans->evaldesc, + pertrans->numSortCols, + pertrans->sortColIdx, + pertrans->sortOperators, + pertrans->sortCollations, + pertrans->sortNullsFirst, work_mem, false); } @@ -543,20 +606,20 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, * Note that when the initial value is pass-by-ref, we must copy it (into * the aggcontext) since we will pfree the transValue later. */ - if (peraggstate->initValueIsNull) - pergroupstate->transValue = peraggstate->initValue; + if (pertrans->initValueIsNull) + pergroupstate->transValue = pertrans->initValue; else { MemoryContext oldContext; oldContext = MemoryContextSwitchTo( aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); - pergroupstate->transValue = datumCopy(peraggstate->initValue, - peraggstate->transtypeByVal, - peraggstate->transtypeLen); + pergroupstate->transValue = datumCopy(pertrans->initValue, + pertrans->transtypeByVal, + pertrans->transtypeLen); MemoryContextSwitchTo(oldContext); } - pergroupstate->transValueIsNull = peraggstate->initValueIsNull; + pergroupstate->transValueIsNull = pertrans->initValueIsNull; /* * If the initial value for the transition state doesn't exist in the @@ -565,11 +628,11 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, * aggregates like max() and min().) The noTransValue flag signals that we * still need to do this. */ - pergroupstate->noTransValue = peraggstate->initValueIsNull; + pergroupstate->noTransValue = pertrans->initValueIsNull; } /* - * Initialize all aggregates for a new group of input values. + * Initialize all aggregate transition states for a new group of input values. * * If there are multiple grouping sets, we initialize only the first numReset * of them (the grouping sets are ordered so that the most specific one, which @@ -580,61 +643,61 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, */ static void initialize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, AggStatePerGroup pergroup, int numReset) { - int aggno; + int transno; int numGroupingSets = Max(aggstate->phase->numsets, 1); int setno = 0; + AggStatePerTrans transstates = aggstate->pertrans; if (numReset < 1) numReset = numGroupingSets; - for (aggno = 0; aggno < aggstate->numaggs; aggno++) + for (transno = 0; transno < aggstate->numtrans; transno++) { - AggStatePerAgg peraggstate = &peragg[aggno]; + AggStatePerTrans pertrans = &transstates[transno]; for (setno = 0; setno < numReset; setno++) { AggStatePerGroup pergroupstate; - pergroupstate = &pergroup[aggno + (setno * (aggstate->numaggs))]; + pergroupstate = &pergroup[transno + (setno * (aggstate->numtrans))]; aggstate->current_set = setno; - initialize_aggregate(aggstate, peraggstate, pergroupstate); + initialize_aggregate(aggstate, pertrans, pergroupstate); } } } /* * Given new input value(s), advance the transition function of one aggregate - * within one grouping set only (already set in aggstate->current_set) + * state within one grouping set only (already set in aggstate->current_set) * * The new values (and null flags) have been preloaded into argument positions - * 1 and up in peraggstate->transfn_fcinfo, so that we needn't copy them again - * to pass to the transition function. We also expect that the static fields - * of the fcinfo are already initialized; that was done by ExecInitAgg(). + * 1 and up in pertrans->transfn_fcinfo, so that we needn't copy them again to + * pass to the transition function. We also expect that the static fields of + * the fcinfo are already initialized; that was done by ExecInitAgg(). * * It doesn't matter which memory context this is called in. */ static void advance_transition_function(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; MemoryContext oldContext; Datum newVal; - if (peraggstate->transfn.fn_strict) + if (pertrans->transfn.fn_strict) { /* * For a strict transfn, nothing happens when there's a NULL input; we * just keep the prior transValue. */ - int numTransInputs = peraggstate->numTransInputs; + int numTransInputs = pertrans->numTransInputs; int i; for (i = 1; i <= numTransInputs; i++) @@ -656,8 +719,8 @@ advance_transition_function(AggState *aggstate, oldContext = MemoryContextSwitchTo( aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); pergroupstate->transValue = datumCopy(fcinfo->arg[1], - peraggstate->transtypeByVal, - peraggstate->transtypeLen); + pertrans->transtypeByVal, + pertrans->transtypeLen); pergroupstate->transValueIsNull = false; pergroupstate->noTransValue = false; MemoryContextSwitchTo(oldContext); @@ -678,8 +741,8 @@ advance_transition_function(AggState *aggstate, /* We run the transition functions in per-input-tuple memory context */ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); - /* set up aggstate->curperagg for AggGetAggref() */ - aggstate->curperagg = peraggstate; + /* set up aggstate->curpertrans for AggGetAggref() */ + aggstate->curpertrans = pertrans; /* * OK to call the transition function @@ -690,22 +753,22 @@ advance_transition_function(AggState *aggstate, newVal = FunctionCallInvoke(fcinfo); - aggstate->curperagg = NULL; + aggstate->curpertrans = NULL; /* * If pass-by-ref datatype, must copy the new value into aggcontext and * pfree the prior transValue. But if transfn returned a pointer to its * first input, we don't need to do anything. */ - if (!peraggstate->transtypeByVal && + if (!pertrans->transtypeByVal && DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue)) { if (!fcinfo->isnull) { MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); newVal = datumCopy(newVal, - peraggstate->transtypeByVal, - peraggstate->transtypeLen); + pertrans->transtypeByVal, + pertrans->transtypeLen); } if (!pergroupstate->transValueIsNull) pfree(DatumGetPointer(pergroupstate->transValue)); @@ -718,26 +781,26 @@ advance_transition_function(AggState *aggstate, } /* - * Advance all the aggregates for one input tuple. The input tuple - * has been stored in tmpcontext->ecxt_outertuple, so that it is accessible - * to ExecEvalExpr. pergroup is the array of per-group structs to use - * (this might be in a hashtable entry). + * Advance each aggregate transition state for one input tuple. The input + * tuple has been stored in tmpcontext->ecxt_outertuple, so that it is + * accessible to ExecEvalExpr. pergroup is the array of per-group structs to + * use (this might be in a hashtable entry). * * When called, CurrentMemoryContext should be the per-query context. */ static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) { - int aggno; + int transno; int setno = 0; int numGroupingSets = Max(aggstate->phase->numsets, 1); - int numAggs = aggstate->numaggs; + int numTrans = aggstate->numtrans; - for (aggno = 0; aggno < numAggs; aggno++) + for (transno = 0; transno < numTrans; transno++) { - AggStatePerAgg peraggstate = &aggstate->peragg[aggno]; - ExprState *filter = peraggstate->aggrefstate->aggfilter; - int numTransInputs = peraggstate->numTransInputs; + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + ExprState *filter = pertrans->aggfilter; + int numTransInputs = pertrans->numTransInputs; int i; TupleTableSlot *slot; @@ -754,12 +817,12 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) } /* Evaluate the current input expressions for this aggregate */ - slot = ExecProject(peraggstate->evalproj, NULL); + slot = ExecProject(pertrans->evalproj, NULL); - if (peraggstate->numSortCols > 0) + if (pertrans->numSortCols > 0) { /* DISTINCT and/or ORDER BY case */ - Assert(slot->tts_nvalid == peraggstate->numInputs); + Assert(slot->tts_nvalid == pertrans->numInputs); /* * If the transfn is strict, we want to check for nullity before @@ -768,7 +831,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) * not numInputs, since nullity in columns used only for sorting * is not relevant here. */ - if (peraggstate->transfn.fn_strict) + if (pertrans->transfn.fn_strict) { for (i = 0; i < numTransInputs; i++) { @@ -782,18 +845,18 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) for (setno = 0; setno < numGroupingSets; setno++) { /* OK, put the tuple into the tuplesort object */ - if (peraggstate->numInputs == 1) - tuplesort_putdatum(peraggstate->sortstates[setno], + if (pertrans->numInputs == 1) + tuplesort_putdatum(pertrans->sortstates[setno], slot->tts_values[0], slot->tts_isnull[0]); else - tuplesort_puttupleslot(peraggstate->sortstates[setno], slot); + tuplesort_puttupleslot(pertrans->sortstates[setno], slot); } } else { /* We can apply the transition function immediately */ - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; /* Load values into fcinfo */ /* Start from 1, since the 0th arg will be the transition value */ @@ -806,11 +869,11 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) for (setno = 0; setno < numGroupingSets; setno++) { - AggStatePerGroup pergroupstate = &pergroup[aggno + (setno * numAggs)]; + AggStatePerGroup pergroupstate = &pergroup[transno + (setno * numTrans)]; aggstate->current_set = setno; - advance_transition_function(aggstate, peraggstate, pergroupstate); + advance_transition_function(aggstate, pertrans, pergroupstate); } } } @@ -841,7 +904,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) */ static void process_ordered_aggregate_single(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { Datum oldVal = (Datum) 0; @@ -849,14 +912,14 @@ process_ordered_aggregate_single(AggState *aggstate, bool haveOldVal = false; MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; MemoryContext oldContext; - bool isDistinct = (peraggstate->numDistinctCols > 0); - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; + bool isDistinct = (pertrans->numDistinctCols > 0); + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; Datum *newVal; bool *isNull; - Assert(peraggstate->numDistinctCols < 2); + Assert(pertrans->numDistinctCols < 2); - tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]); + tuplesort_performsort(pertrans->sortstates[aggstate->current_set]); /* Load the column into argument 1 (arg 0 will be transition value) */ newVal = fcinfo->arg + 1; @@ -868,7 +931,7 @@ process_ordered_aggregate_single(AggState *aggstate, * pfree them when they are no longer needed. */ - while (tuplesort_getdatum(peraggstate->sortstates[aggstate->current_set], + while (tuplesort_getdatum(pertrans->sortstates[aggstate->current_set], true, newVal, isNull)) { /* @@ -887,18 +950,18 @@ process_ordered_aggregate_single(AggState *aggstate, haveOldVal && ((oldIsNull && *isNull) || (!oldIsNull && !*isNull && - DatumGetBool(FunctionCall2(&peraggstate->equalfns[0], + DatumGetBool(FunctionCall2(&pertrans->equalfns[0], oldVal, *newVal))))) { /* equal to prior, so forget this one */ - if (!peraggstate->inputtypeByVal && !*isNull) + if (!pertrans->inputtypeByVal && !*isNull) pfree(DatumGetPointer(*newVal)); } else { - advance_transition_function(aggstate, peraggstate, pergroupstate); + advance_transition_function(aggstate, pertrans, pergroupstate); /* forget the old value, if any */ - if (!oldIsNull && !peraggstate->inputtypeByVal) + if (!oldIsNull && !pertrans->inputtypeByVal) pfree(DatumGetPointer(oldVal)); /* and remember the new one for subsequent equality checks */ oldVal = *newVal; @@ -909,11 +972,11 @@ process_ordered_aggregate_single(AggState *aggstate, MemoryContextSwitchTo(oldContext); } - if (!oldIsNull && !peraggstate->inputtypeByVal) + if (!oldIsNull && !pertrans->inputtypeByVal) pfree(DatumGetPointer(oldVal)); - tuplesort_end(peraggstate->sortstates[aggstate->current_set]); - peraggstate->sortstates[aggstate->current_set] = NULL; + tuplesort_end(pertrans->sortstates[aggstate->current_set]); + pertrans->sortstates[aggstate->current_set] = NULL; } /* @@ -930,25 +993,25 @@ process_ordered_aggregate_single(AggState *aggstate, */ static void process_ordered_aggregate_multi(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; - TupleTableSlot *slot1 = peraggstate->evalslot; - TupleTableSlot *slot2 = peraggstate->uniqslot; - int numTransInputs = peraggstate->numTransInputs; - int numDistinctCols = peraggstate->numDistinctCols; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; + TupleTableSlot *slot1 = pertrans->evalslot; + TupleTableSlot *slot2 = pertrans->uniqslot; + int numTransInputs = pertrans->numTransInputs; + int numDistinctCols = pertrans->numDistinctCols; bool haveOldValue = false; int i; - tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]); + tuplesort_performsort(pertrans->sortstates[aggstate->current_set]); ExecClearTuple(slot1); if (slot2) ExecClearTuple(slot2); - while (tuplesort_gettupleslot(peraggstate->sortstates[aggstate->current_set], + while (tuplesort_gettupleslot(pertrans->sortstates[aggstate->current_set], true, slot1)) { /* @@ -962,8 +1025,8 @@ process_ordered_aggregate_multi(AggState *aggstate, !haveOldValue || !execTuplesMatch(slot1, slot2, numDistinctCols, - peraggstate->sortColIdx, - peraggstate->equalfns, + pertrans->sortColIdx, + pertrans->equalfns, workcontext)) { /* Load values into fcinfo */ @@ -974,7 +1037,7 @@ process_ordered_aggregate_multi(AggState *aggstate, fcinfo->argnull[i + 1] = slot1->tts_isnull[i]; } - advance_transition_function(aggstate, peraggstate, pergroupstate); + advance_transition_function(aggstate, pertrans, pergroupstate); if (numDistinctCols > 0) { @@ -997,8 +1060,8 @@ process_ordered_aggregate_multi(AggState *aggstate, if (slot2) ExecClearTuple(slot2); - tuplesort_end(peraggstate->sortstates[aggstate->current_set]); - peraggstate->sortstates[aggstate->current_set] = NULL; + tuplesort_end(pertrans->sortstates[aggstate->current_set]); + pertrans->sortstates[aggstate->current_set] = NULL; } /* @@ -1009,10 +1072,14 @@ process_ordered_aggregate_multi(AggState *aggstate, * * The finalfunction will be run, and the result delivered, in the * output-tuple context; caller's CurrentMemoryContext does not matter. + * + * The finalfn uses the state as set in the transno. This also might be + * being used by another aggregate function, so it's important that we do + * nothing destructive here. */ static void finalize_aggregate(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull) { @@ -1021,6 +1088,7 @@ finalize_aggregate(AggState *aggstate, MemoryContext oldContext; int i; ListCell *lc; + AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno]; oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); @@ -1031,7 +1099,7 @@ finalize_aggregate(AggState *aggstate, * for the transition state value. */ i = 1; - foreach(lc, peraggstate->aggrefstate->aggdirectargs) + foreach(lc, pertrans->aggdirectargs) { ExprState *expr = (ExprState *) lfirst(lc); @@ -1046,16 +1114,16 @@ finalize_aggregate(AggState *aggstate, /* * Apply the agg's finalfn if one is provided, else return transValue. */ - if (OidIsValid(peraggstate->finalfn_oid)) + if (OidIsValid(peragg->finalfn_oid)) { - int numFinalArgs = peraggstate->numFinalArgs; + int numFinalArgs = peragg->numFinalArgs; - /* set up aggstate->curperagg for AggGetAggref() */ - aggstate->curperagg = peraggstate; + /* set up aggstate->curpertrans for AggGetAggref() */ + aggstate->curpertrans = pertrans; - InitFunctionCallInfoData(fcinfo, &peraggstate->finalfn, + InitFunctionCallInfoData(fcinfo, &peragg->finalfn, numFinalArgs, - peraggstate->aggCollation, + pertrans->aggCollation, (void *) aggstate, NULL); /* Fill in the transition state value */ @@ -1082,7 +1150,7 @@ finalize_aggregate(AggState *aggstate, *resultVal = FunctionCallInvoke(&fcinfo); *resultIsNull = fcinfo.isnull; } - aggstate->curperagg = NULL; + aggstate->curpertrans = NULL; } else { @@ -1093,12 +1161,12 @@ finalize_aggregate(AggState *aggstate, /* * If result is pass-by-ref, make sure it is in the right context. */ - if (!peraggstate->resulttypeByVal && !*resultIsNull && + if (!peragg->resulttypeByVal && !*resultIsNull && !MemoryContextContains(CurrentMemoryContext, DatumGetPointer(*resultVal))) *resultVal = datumCopy(*resultVal, - peraggstate->resulttypeByVal, - peraggstate->resulttypeLen); + peragg->resulttypeByVal, + peragg->resulttypeLen); MemoryContextSwitchTo(oldContext); } @@ -1173,7 +1241,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet */ static void finalize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, + AggStatePerAgg peraggs, AggStatePerGroup pergroup, int currentSet) { @@ -1189,26 +1257,28 @@ finalize_aggregates(AggState *aggstate, for (aggno = 0; aggno < aggstate->numaggs; aggno++) { - AggStatePerAgg peraggstate = &peragg[aggno]; + AggStatePerAgg peragg = &peraggs[aggno]; + int transno = peragg->transno; + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; AggStatePerGroup pergroupstate; - pergroupstate = &pergroup[aggno + (currentSet * (aggstate->numaggs))]; + pergroupstate = &pergroup[transno + (currentSet * (aggstate->numtrans))]; - if (peraggstate->numSortCols > 0) + if (pertrans->numSortCols > 0) { Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED); - if (peraggstate->numInputs == 1) + if (pertrans->numInputs == 1) process_ordered_aggregate_single(aggstate, - peraggstate, + pertrans, pergroupstate); else process_ordered_aggregate_multi(aggstate, - peraggstate, + pertrans, pergroupstate); } - finalize_aggregate(aggstate, peraggstate, pergroupstate, + finalize_aggregate(aggstate, peragg, pergroupstate, &aggvalues[aggno], &aggnulls[aggno]); } } @@ -1428,7 +1498,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot) if (isnew) { /* initialize aggregates for new tuple group */ - initialize_aggregates(aggstate, aggstate->peragg, entry->pergroup, 0); + initialize_aggregates(aggstate, entry->pergroup, 0); } return entry; @@ -1716,7 +1786,7 @@ agg_retrieve_direct(AggState *aggstate) /* * Initialize working state for a new input tuple group. */ - initialize_aggregates(aggstate, peragg, pergroup, numReset); + initialize_aggregates(aggstate, pergroup, numReset); if (aggstate->grp_firstTuple != NULL) { @@ -1945,17 +2015,18 @@ AggState * ExecInitAgg(Agg *node, EState *estate, int eflags) { AggState *aggstate; - AggStatePerAgg peragg; + AggStatePerAgg peraggs; + AggStatePerTrans pertransstates; Plan *outerPlan; ExprContext *econtext; int numaggs, + transno, aggno; int phase; ListCell *l; Bitmapset *all_grouped_cols = NULL; int numGroupingSets = 1; int numPhases; - int currentsortno = 0; int i = 0; int j = 0; @@ -1971,12 +2042,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->aggs = NIL; aggstate->numaggs = 0; + aggstate->numtrans = 0; aggstate->maxsets = 0; aggstate->hashfunctions = NULL; aggstate->projected_set = -1; aggstate->current_set = 0; aggstate->peragg = NULL; - aggstate->curperagg = NULL; + aggstate->pertrans = NULL; + aggstate->curpertrans = NULL; aggstate->agg_done = false; aggstate->input_done = false; aggstate->pergroup = NULL; @@ -2209,8 +2282,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs); econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs); - peragg = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs); - aggstate->peragg = peragg; + peraggs = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs); + pertransstates = (AggStatePerTrans) palloc0(sizeof(AggStatePerTransData) * numaggs); + + aggstate->peragg = peraggs; + aggstate->pertrans = pertransstates; if (node->aggstrategy == AGG_HASHED) { @@ -2230,71 +2306,86 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->pergroup = pergroup; } - /* + /* ----------------- * Perform lookups of aggregate function info, and initialize the - * unchanging fields of the per-agg data. We also detect duplicate - * aggregates (for example, "SELECT sum(x) ... HAVING sum(x) > 0"). When - * duplicates are detected, we only make an AggStatePerAgg struct for the - * first one. The clones are simply pointed at the same result entry by - * giving them duplicate aggno values. + * unchanging fields of the per-agg and per-trans data. + * + * We try to optimize by detecting duplicate aggregate functions so that + * their state and final values are re-used, rather than needlessly being + * re-calculated independently. We also detect aggregates that are not + * the same, but which can share the same transition state. + * + * Scenarios: + * + * 1. An aggregate function appears more than once in query: + * + * SELECT SUM(x) FROM ... HAVING SUM(x) > 0 + * + * Since the aggregates are the identical, we only need to calculate + * the calculate it once. Both aggregates will share the same 'aggno' + * value. + * + * 2. Two different aggregate functions appear in the query, but the + * aggregates have the same transition function and initial value, but + * different final function: + * + * SELECT SUM(x), AVG(x) FROM ... + * + * In this case we must create a new peragg for the varying aggregate, + * and need to call the final functions separately, but can share the + * same transition state. + * + * For either of these optimizations to be valid, the aggregate's + * arguments must be the same, including any modifiers such as ORDER BY, + * DISTINCT and FILTER, and they mustn't contain any volatile functions. + * ----------------- */ aggno = -1; + transno = -1; foreach(l, aggstate->aggs) { AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l); Aggref *aggref = (Aggref *) aggrefstate->xprstate.expr; - AggStatePerAgg peraggstate; + AggStatePerAgg peragg; + AggStatePerTrans pertrans; + int existing_aggno; + int existing_transno; + List *same_input_transnos; Oid inputTypes[FUNC_MAX_ARGS]; int numArguments; int numDirectArgs; - int numInputs; - int numSortCols; - int numDistinctCols; - List *sortlist; HeapTuple aggTuple; Form_pg_aggregate aggform; - Oid aggtranstype; AclResult aclresult; Oid transfn_oid, finalfn_oid; - Expr *transfnexpr, - *finalfnexpr; + Expr *finalfnexpr; + Oid aggtranstype; Datum textInitVal; - int i; - ListCell *lc; + Datum initValue; + bool initValueIsNull; /* Planner should have assigned aggregate to correct level */ Assert(aggref->agglevelsup == 0); - /* Look for a previous duplicate aggregate */ - for (i = 0; i <= aggno; i++) + /* 1. Check for already processed aggs which can be re-used */ + existing_aggno = find_compatible_peragg(aggref, aggstate, aggno, + &same_input_transnos); + if (existing_aggno != -1) { - if (equal(aggref, peragg[i].aggref) && - !contain_volatile_functions((Node *) aggref)) - break; - } - if (i <= aggno) - { - /* Found a match to an existing entry, so just mark it */ - aggrefstate->aggno = i; + /* + * Existing compatible agg found. so just point the Aggref to the + * same per-agg struct. + */ + aggrefstate->aggno = existing_aggno; continue; } - /* Nope, so assign a new PerAgg record */ - peraggstate = &peragg[++aggno]; - /* Mark Aggref state node with assigned index in the result array */ + peragg = &peraggs[++aggno]; + peragg->aggref = aggref; aggrefstate->aggno = aggno; - /* Begin filling in the peraggstate data */ - peraggstate->aggrefstate = aggrefstate; - peraggstate->aggref = aggref; - peraggstate->sortstates = (Tuplesortstate **) - palloc0(sizeof(Tuplesortstate *) * numGroupingSets); - - for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++) - peraggstate->sortstates[currentsortno] = NULL; - /* Fetch the pg_aggregate row */ aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(aggref->aggfnoid)); @@ -2311,8 +2402,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(aggref->aggfnoid)); InvokeFunctionExecuteHook(aggref->aggfnoid); - peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn; - peraggstate->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + transfn_oid = aggform->aggtransfn; + peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; /* Check that aggregate owner has permission to call component fns */ { @@ -2350,74 +2441,43 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * agg accepts ANY or a polymorphic type. */ numArguments = get_aggregate_argtypes(aggref, inputTypes); - peraggstate->numArguments = numArguments; /* Count the "direct" arguments, if any */ numDirectArgs = list_length(aggref->aggdirectargs); - /* Count the number of aggregated input columns */ - numInputs = list_length(aggref->args); - peraggstate->numInputs = numInputs; - - /* Detect how many arguments to pass to the transfn */ - if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) - peraggstate->numTransInputs = numInputs; - else - peraggstate->numTransInputs = numArguments; - - /* Detect how many arguments to pass to the finalfn */ - if (aggform->aggfinalextra) - peraggstate->numFinalArgs = numArguments + 1; - else - peraggstate->numFinalArgs = numDirectArgs + 1; - /* resolve actual type of transition state, if polymorphic */ aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, aggform->aggtranstype, inputTypes, numArguments); - /* build expression trees using actual argument & result types */ - build_aggregate_fnexprs(inputTypes, - numArguments, - numDirectArgs, - peraggstate->numFinalArgs, - aggref->aggvariadic, - aggtranstype, - aggref->aggtype, - aggref->inputcollid, - transfn_oid, - InvalidOid, /* invtrans is not needed here */ - finalfn_oid, - &transfnexpr, - NULL, - &finalfnexpr); - - /* set up infrastructure for calling the transfn and finalfn */ - fmgr_info(transfn_oid, &peraggstate->transfn); - fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn); + /* Detect how many arguments to pass to the finalfn */ + if (aggform->aggfinalextra) + peragg->numFinalArgs = numArguments + 1; + else + peragg->numFinalArgs = numDirectArgs + 1; + /* + * build expression trees using actual argument & result types for the + * finalfn, if it exists + */ if (OidIsValid(finalfn_oid)) { - fmgr_info(finalfn_oid, &peraggstate->finalfn); - fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn); + build_aggregate_finalfn_expr(inputTypes, + peragg->numFinalArgs, + aggtranstype, + aggref->aggtype, + aggref->inputcollid, + finalfn_oid, + &finalfnexpr); + fmgr_info(finalfn_oid, &peragg->finalfn); + fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn); } - peraggstate->aggCollation = aggref->inputcollid; - - InitFunctionCallInfoData(peraggstate->transfn_fcinfo, - &peraggstate->transfn, - peraggstate->numTransInputs + 1, - peraggstate->aggCollation, - (void *) aggstate, NULL); - - /* get info about relevant datatypes */ + /* get info about the result type's datatype */ get_typlenbyval(aggref->aggtype, - &peraggstate->resulttypeLen, - &peraggstate->resulttypeByVal); - get_typlenbyval(aggtranstype, - &peraggstate->transtypeLen, - &peraggstate->transtypeByVal); + &peragg->resulttypeLen, + &peragg->resulttypeByVal); /* * initval is potentially null, so don't try to access it as a struct @@ -2425,161 +2485,292 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple, Anum_pg_aggregate_agginitval, - &peraggstate->initValueIsNull); - - if (peraggstate->initValueIsNull) - peraggstate->initValue = (Datum) 0; + &initValueIsNull); + if (initValueIsNull) + initValue = (Datum) 0; else - peraggstate->initValue = GetAggInitVal(textInitVal, - aggtranstype); + initValue = GetAggInitVal(textInitVal, aggtranstype); /* - * If the transfn is strict and the initval is NULL, make sure input - * type and transtype are the same (or at least binary-compatible), so - * that it's OK to use the first aggregated input value as the initial - * transValue. This should have been checked at agg definition time, - * but we must check again in case the transfn's strictness property - * has been changed. - */ - if (peraggstate->transfn.fn_strict && peraggstate->initValueIsNull) - { - if (numArguments <= numDirectArgs || - !IsBinaryCoercible(inputTypes[numDirectArgs], aggtranstype)) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate %u needs to have compatible input type and transition type", - aggref->aggfnoid))); - } - - /* - * Get a tupledesc corresponding to the aggregated inputs (including - * sort expressions) of the agg. - */ - peraggstate->evaldesc = ExecTypeFromTL(aggref->args, false); - - /* Create slot we're going to do argument evaluation in */ - peraggstate->evalslot = ExecInitExtraTupleSlot(estate); - ExecSetSlotDescriptor(peraggstate->evalslot, peraggstate->evaldesc); - - /* Set up projection info for evaluation */ - peraggstate->evalproj = ExecBuildProjectionInfo(aggrefstate->args, - aggstate->tmpcontext, - peraggstate->evalslot, - NULL); - - /* - * If we're doing either DISTINCT or ORDER BY for a plain agg, then we - * have a list of SortGroupClause nodes; fish out the data in them and - * stick them into arrays. We ignore ORDER BY for an ordered-set agg, - * however; the agg's transfn and finalfn are responsible for that. + * 2. Build working state for invoking the transition function, or + * look up previously initialized working state, if we can share it. * - * Note that by construction, if there is a DISTINCT clause then the - * ORDER BY clause is a prefix of it (see transformDistinctClause). + * find_compatible_peragg() already collected a list of per-Trans's + * with the same inputs. Check if any of them have the same transition + * function and initial value. */ - if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + existing_transno = find_compatible_pertrans(aggstate, aggref, + transfn_oid, aggtranstype, + initValue, initValueIsNull, + same_input_transnos); + if (existing_transno != -1) { - sortlist = NIL; - numSortCols = numDistinctCols = 0; - } - else if (aggref->aggdistinct) - { - sortlist = aggref->aggdistinct; - numSortCols = numDistinctCols = list_length(sortlist); - Assert(numSortCols >= list_length(aggref->aggorder)); + /* + * Existing compatible trans found, so just point the 'peragg' to + * the same per-trans struct. + */ + pertrans = &pertransstates[existing_transno]; + peragg->transno = existing_transno; } else { - sortlist = aggref->aggorder; - numSortCols = list_length(sortlist); - numDistinctCols = 0; + pertrans = &pertransstates[++transno]; + build_pertrans_for_aggref(pertrans, aggstate, estate, + aggref, transfn_oid, aggtranstype, + initValue, initValueIsNull, + inputTypes, numArguments); + peragg->transno = transno; } - - peraggstate->numSortCols = numSortCols; - peraggstate->numDistinctCols = numDistinctCols; - - if (numSortCols > 0) - { - /* - * We don't implement DISTINCT or ORDER BY aggs in the HASHED case - * (yet) - */ - Assert(node->aggstrategy != AGG_HASHED); - - /* If we have only one input, we need its len/byval info. */ - if (numInputs == 1) - { - get_typlenbyval(inputTypes[numDirectArgs], - &peraggstate->inputtypeLen, - &peraggstate->inputtypeByVal); - } - else if (numDistinctCols > 0) - { - /* we will need an extra slot to store prior values */ - peraggstate->uniqslot = ExecInitExtraTupleSlot(estate); - ExecSetSlotDescriptor(peraggstate->uniqslot, - peraggstate->evaldesc); - } - - /* Extract the sort information for use later */ - peraggstate->sortColIdx = - (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber)); - peraggstate->sortOperators = - (Oid *) palloc(numSortCols * sizeof(Oid)); - peraggstate->sortCollations = - (Oid *) palloc(numSortCols * sizeof(Oid)); - peraggstate->sortNullsFirst = - (bool *) palloc(numSortCols * sizeof(bool)); - - i = 0; - foreach(lc, sortlist) - { - SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); - TargetEntry *tle = get_sortgroupclause_tle(sortcl, - aggref->args); - - /* the parser should have made sure of this */ - Assert(OidIsValid(sortcl->sortop)); - - peraggstate->sortColIdx[i] = tle->resno; - peraggstate->sortOperators[i] = sortcl->sortop; - peraggstate->sortCollations[i] = exprCollation((Node *) tle->expr); - peraggstate->sortNullsFirst[i] = sortcl->nulls_first; - i++; - } - Assert(i == numSortCols); - } - - if (aggref->aggdistinct) - { - Assert(numArguments > 0); - - /* - * We need the equal function for each DISTINCT comparison we will - * make. - */ - peraggstate->equalfns = - (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo)); - - i = 0; - foreach(lc, aggref->aggdistinct) - { - SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); - - fmgr_info(get_opcode(sortcl->eqop), &peraggstate->equalfns[i]); - i++; - } - Assert(i == numDistinctCols); - } - ReleaseSysCache(aggTuple); } - /* Update numaggs to match number of unique aggregates found */ + /* + * Update numaggs to match the number of unique aggregates found. Also set + * numstates to the number of unique aggregate states found. + */ aggstate->numaggs = aggno + 1; + aggstate->numtrans = transno + 1; return aggstate; } +/* + * Build the state needed to calculate a state value for an aggregate. + * + * This initializes all the fields in 'pertrans'. 'aggref' is the aggregate + * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest + * of the arguments could be calculated from 'aggref', but the caller has + * calculated them already, so might as well pass them. + */ +static void +build_pertrans_for_aggref(AggStatePerTrans pertrans, + AggState *aggstate, EState *estate, + Aggref *aggref, + Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + Oid *inputTypes, int numArguments) +{ + int numGroupingSets = Max(aggstate->maxsets, 1); + Expr *transfnexpr; + ListCell *lc; + int numInputs; + int numDirectArgs; + List *sortlist; + int numSortCols; + int numDistinctCols; + int naggs; + int i; + + /* Begin filling in the pertrans data */ + pertrans->aggref = aggref; + pertrans->aggCollation = aggref->inputcollid; + pertrans->transfn_oid = aggtransfn; + pertrans->initValue = initValue; + pertrans->initValueIsNull = initValueIsNull; + + /* Count the "direct" arguments, if any */ + numDirectArgs = list_length(aggref->aggdirectargs); + + /* Count the number of aggregated input columns */ + pertrans->numInputs = numInputs = list_length(aggref->args); + + pertrans->aggtranstype = aggtranstype; + + /* Detect how many arguments to pass to the transfn */ + if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + pertrans->numTransInputs = numInputs; + else + pertrans->numTransInputs = numArguments; + + /* + * Set up infrastructure for calling the transfn + */ + build_aggregate_transfn_expr(inputTypes, + numArguments, + numDirectArgs, + aggref->aggvariadic, + aggtranstype, + aggref->inputcollid, + aggtransfn, + InvalidOid, /* invtrans is not needed here */ + &transfnexpr, + NULL); + fmgr_info(aggtransfn, &pertrans->transfn); + fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn); + + InitFunctionCallInfoData(pertrans->transfn_fcinfo, + &pertrans->transfn, + pertrans->numTransInputs + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + /* + * If the transfn is strict and the initval is NULL, make sure input type + * and transtype are the same (or at least binary-compatible), so that + * it's OK to use the first aggregated input value as the initial + * transValue. This should have been checked at agg definition time, but + * we must check again in case the transfn's strictness property has been + * changed. + */ + if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + { + if (numArguments <= numDirectArgs || + !IsBinaryCoercible(inputTypes[numDirectArgs], + aggtranstype)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate %u needs to have compatible input type and transition type", + aggref->aggfnoid))); + } + + /* get info about the state value's datatype */ + get_typlenbyval(aggtranstype, + &pertrans->transtypeLen, + &pertrans->transtypeByVal); + + /* + * Get a tupledesc corresponding to the aggregated inputs (including sort + * expressions) of the agg. + */ + pertrans->evaldesc = ExecTypeFromTL(aggref->args, false); + + /* Create slot we're going to do argument evaluation in */ + pertrans->evalslot = ExecInitExtraTupleSlot(estate); + ExecSetSlotDescriptor(pertrans->evalslot, pertrans->evaldesc); + + /* Initialize the input and FILTER expressions */ + naggs = aggstate->numaggs; + pertrans->aggfilter = ExecInitExpr(aggref->aggfilter, + (PlanState *) aggstate); + pertrans->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs, + (PlanState *) aggstate); + pertrans->args = (List *) ExecInitExpr((Expr *) aggref->args, + (PlanState *) aggstate); + + /* + * Complain if the aggregate's arguments contain any aggregates; nested + * agg functions are semantically nonsensical. (This should have been + * caught earlier, but we defend against it here anyway.) + */ + if (naggs != aggstate->numaggs) + ereport(ERROR, + (errcode(ERRCODE_GROUPING_ERROR), + errmsg("aggregate function calls cannot be nested"))); + + /* Set up projection info for evaluation */ + pertrans->evalproj = ExecBuildProjectionInfo(pertrans->args, + aggstate->tmpcontext, + pertrans->evalslot, + NULL); + + /* + * If we're doing either DISTINCT or ORDER BY for a plain agg, then we + * have a list of SortGroupClause nodes; fish out the data in them and + * stick them into arrays. We ignore ORDER BY for an ordered-set agg, + * however; the agg's transfn and finalfn are responsible for that. + * + * Note that by construction, if there is a DISTINCT clause then the ORDER + * BY clause is a prefix of it (see transformDistinctClause). + */ + if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + { + sortlist = NIL; + numSortCols = numDistinctCols = 0; + } + else if (aggref->aggdistinct) + { + sortlist = aggref->aggdistinct; + numSortCols = numDistinctCols = list_length(sortlist); + Assert(numSortCols >= list_length(aggref->aggorder)); + } + else + { + sortlist = aggref->aggorder; + numSortCols = list_length(sortlist); + numDistinctCols = 0; + } + + pertrans->numSortCols = numSortCols; + pertrans->numDistinctCols = numDistinctCols; + + if (numSortCols > 0) + { + /* + * We don't implement DISTINCT or ORDER BY aggs in the HASHED case + * (yet) + */ + Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED); + + /* If we have only one input, we need its len/byval info. */ + if (numInputs == 1) + { + get_typlenbyval(inputTypes[numDirectArgs], + &pertrans->inputtypeLen, + &pertrans->inputtypeByVal); + } + else if (numDistinctCols > 0) + { + /* we will need an extra slot to store prior values */ + pertrans->uniqslot = ExecInitExtraTupleSlot(estate); + ExecSetSlotDescriptor(pertrans->uniqslot, + pertrans->evaldesc); + } + + /* Extract the sort information for use later */ + pertrans->sortColIdx = + (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber)); + pertrans->sortOperators = + (Oid *) palloc(numSortCols * sizeof(Oid)); + pertrans->sortCollations = + (Oid *) palloc(numSortCols * sizeof(Oid)); + pertrans->sortNullsFirst = + (bool *) palloc(numSortCols * sizeof(bool)); + + i = 0; + foreach(lc, sortlist) + { + SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); + TargetEntry *tle = get_sortgroupclause_tle(sortcl, aggref->args); + + /* the parser should have made sure of this */ + Assert(OidIsValid(sortcl->sortop)); + + pertrans->sortColIdx[i] = tle->resno; + pertrans->sortOperators[i] = sortcl->sortop; + pertrans->sortCollations[i] = exprCollation((Node *) tle->expr); + pertrans->sortNullsFirst[i] = sortcl->nulls_first; + i++; + } + Assert(i == numSortCols); + } + + if (aggref->aggdistinct) + { + Assert(numArguments > 0); + + /* + * We need the equal function for each DISTINCT comparison we will + * make. + */ + pertrans->equalfns = + (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo)); + + i = 0; + foreach(lc, aggref->aggdistinct) + { + SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); + + fmgr_info(get_opcode(sortcl->eqop), &pertrans->equalfns[i]); + i++; + } + Assert(i == numDistinctCols); + } + + pertrans->sortstates = (Tuplesortstate **) + palloc0(sizeof(Tuplesortstate *) * numGroupingSets); +} + + static Datum GetAggInitVal(Datum textInitVal, Oid transtype) { @@ -2596,11 +2787,130 @@ GetAggInitVal(Datum textInitVal, Oid transtype) return initVal; } +/* + * find_compatible_peragg - search for a previously initialized per-Agg struct + * + * Searches the previously looked at aggregates to find one which is compatible + * with this one, with the same input parameters. If no compatible aggregate + * can be found, returns -1. + * + * As a side-effect, this also collects a list of existing per-Trans structs + * with matching inputs. If no identical Aggref is found, the list is passed + * later to find_compatible_perstate, to see if we can at least reuse the + * state value of another aggregate. + */ +static int +find_compatible_peragg(Aggref *newagg, AggState *aggstate, + int lastaggno, List **same_input_transnos) +{ + int aggno; + AggStatePerAgg peraggs; + + *same_input_transnos = NIL; + + /* we mustn't reuse the aggref if it contains volatile function calls */ + if (contain_volatile_functions((Node *) newagg)) + return -1; + + peraggs = aggstate->peragg; + + /* + * Search through the list of already seen aggregates. If we find an + * existing aggregate with the same aggregate function and input + * parameters as an existing one, then we can re-use that one. While + * searching, we'll also collect a list of Aggrefs with the same input + * parameters. If no matching Aggref is found, the caller can potentially + * still re-use the transition state of one of them. + */ + for (aggno = 0; aggno <= lastaggno; aggno++) + { + AggStatePerAgg peragg; + Aggref *existingRef; + + peragg = &peraggs[aggno]; + existingRef = peragg->aggref; + + /* all of the following must be the same or it's no match */ + if (newagg->inputcollid != existingRef->inputcollid || + newagg->aggstar != existingRef->aggstar || + newagg->aggvariadic != existingRef->aggvariadic || + newagg->aggkind != existingRef->aggkind || + !equal(newagg->aggdirectargs, existingRef->aggdirectargs) || + !equal(newagg->args, existingRef->args) || + !equal(newagg->aggorder, existingRef->aggorder) || + !equal(newagg->aggdistinct, existingRef->aggdistinct) || + !equal(newagg->aggfilter, existingRef->aggfilter)) + continue; + + /* if it's the same aggregate function then report exact match */ + if (newagg->aggfnoid == existingRef->aggfnoid && + newagg->aggtype == existingRef->aggtype && + newagg->aggcollid == existingRef->aggcollid) + { + list_free(*same_input_transnos); + *same_input_transnos = NIL; + return aggno; + } + + /* + * Not identical, but it had the same inputs. Return it to the caller, + * in case we can re-use its per-trans state. + */ + *same_input_transnos = lappend_int(*same_input_transnos, + peragg->transno); + } + + return -1; +} + +/* + * find_compatible_pertrans - search for a previously initialized per-Trans + * struct + * + * Searches the list of transnos for a per-Trans struct with the same + * transition state and initial condition. (The inputs have already been + * verified to match.) + */ +static int +find_compatible_pertrans(AggState *aggstate, Aggref *newagg, + Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + List *transnos) +{ + ListCell *lc; + + foreach(lc, transnos) + { + int transno = lfirst_int(lc); + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + + /* + * if the transfns or transition state types are not the same then the + * state can't be shared. + */ + if (aggtransfn != pertrans->transfn_oid || + aggtranstype != pertrans->aggtranstype) + continue; + + /* Check that the initial condition matches, too. */ + if (initValueIsNull && pertrans->initValueIsNull) + return transno; + + if (!initValueIsNull && !pertrans->initValueIsNull && + datumIsEqual(initValue, pertrans->initValue, + pertrans->transtypeByVal, pertrans->transtypeLen)) + { + return transno; + } + } + return -1; +} + void ExecEndAgg(AggState *node) { PlanState *outerPlan; - int aggno; + int transno; int numGroupingSets = Max(node->maxsets, 1); int setno; @@ -2611,14 +2921,14 @@ ExecEndAgg(AggState *node) if (node->sort_out) tuplesort_end(node->sort_out); - for (aggno = 0; aggno < node->numaggs; aggno++) + for (transno = 0; transno < node->numtrans; transno++) { - AggStatePerAgg peraggstate = &node->peragg[aggno]; + AggStatePerTrans pertrans = &node->pertrans[transno]; for (setno = 0; setno < numGroupingSets; setno++) { - if (peraggstate->sortstates[setno]) - tuplesort_end(peraggstate->sortstates[setno]); + if (pertrans->sortstates[setno]) + tuplesort_end(pertrans->sortstates[setno]); } } @@ -2646,7 +2956,7 @@ ExecReScanAgg(AggState *node) ExprContext *econtext = node->ss.ps.ps_ExprContext; PlanState *outerPlan = outerPlanState(node); Agg *aggnode = (Agg *) node->ss.ps.plan; - int aggno; + int transno; int numGroupingSets = Max(node->maxsets, 1); int setno; @@ -2678,16 +2988,16 @@ ExecReScanAgg(AggState *node) } /* Make sure we have closed any open tuplesorts */ - for (aggno = 0; aggno < node->numaggs; aggno++) + for (transno = 0; transno < node->numtrans; transno++) { for (setno = 0; setno < numGroupingSets; setno++) { - AggStatePerAgg peraggstate = &node->peragg[aggno]; + AggStatePerTrans pertrans = &node->pertrans[transno]; - if (peraggstate->sortstates[setno]) + if (pertrans->sortstates[setno]) { - tuplesort_end(peraggstate->sortstates[setno]); - peraggstate->sortstates[setno] = NULL; + tuplesort_end(pertrans->sortstates[setno]); + pertrans->sortstates[setno] = NULL; } } } @@ -2811,10 +3121,12 @@ AggGetAggref(FunctionCallInfo fcinfo) { if (fcinfo->context && IsA(fcinfo->context, AggState)) { - AggStatePerAgg curperagg = ((AggState *) fcinfo->context)->curperagg; + AggStatePerTrans curpertrans; - if (curperagg) - return curperagg->aggref; + curpertrans = ((AggState *) fcinfo->context)->curpertrans; + + if (curpertrans) + return curpertrans->aggref; } return NULL; } diff --git a/src/backend/executor/nodeWindowAgg.c b/src/backend/executor/nodeWindowAgg.c index ecf96f8c19..c371d4db14 100644 --- a/src/backend/executor/nodeWindowAgg.c +++ b/src/backend/executor/nodeWindowAgg.c @@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc, numArguments); /* build expression trees using actual argument & result types */ - build_aggregate_fnexprs(inputTypes, - numArguments, - 0, /* no ordered-set window functions yet */ - peraggstate->numFinalArgs, - false, /* no variadic window functions yet */ - aggtranstype, - wfunc->wintype, - wfunc->inputcollid, - transfn_oid, - invtransfn_oid, - finalfn_oid, - &transfnexpr, - &invtransfnexpr, - &finalfnexpr); + build_aggregate_transfn_expr(inputTypes, + numArguments, + 0, /* no ordered-set window functions yet */ + false, /* no variadic window functions yet */ + wfunc->wintype, + wfunc->inputcollid, + transfn_oid, + invtransfn_oid, + &transfnexpr, + &invtransfnexpr); /* set up infrastructure for calling the transfn(s) and finalfn */ fmgr_info(transfn_oid, &peraggstate->transfn); @@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc, if (OidIsValid(finalfn_oid)) { + build_aggregate_finalfn_expr(inputTypes, + peraggstate->numFinalArgs, + aggtranstype, + wfunc->wintype, + wfunc->inputcollid, + finalfn_oid, + &finalfnexpr); fmgr_info(finalfn_oid, &peraggstate->finalfn); fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn); } diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index 3846b569d6..5b0d568478 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -1829,44 +1829,40 @@ resolve_aggregate_transtype(Oid aggfuncid, } /* - * Create expression trees for the transition and final functions - * of an aggregate. These are needed so that polymorphic functions - * can be used within an aggregate --- without the expression trees, - * such functions would not know the datatypes they are supposed to use. - * (The trees will never actually be executed, however, so we can skimp - * a bit on correctness.) + * Create an expression tree for the transition function of an aggregate. + * This is needed so that polymorphic functions can be used within an + * aggregate --- without the expression tree, such functions would not know + * the datatypes they are supposed to use. (The trees will never actually + * be executed, however, so we can skimp a bit on correctness.) * - * agg_input_types, agg_state_type, agg_result_type identify the input, - * transition, and result types of the aggregate. These should all be - * resolved to actual types (ie, none should ever be ANYELEMENT etc). + * agg_input_types and agg_state_type identifies the input types of the + * aggregate. These should be resolved to actual types (ie, none should + * ever be ANYELEMENT etc). * agg_input_collation is the aggregate function's input collation. * * For an ordered-set aggregate, remember that agg_input_types describes * the direct arguments followed by the aggregated arguments. * - * transfn_oid, invtransfn_oid and finalfn_oid identify the funcs to be - * called; the latter two may be InvalidOid. + * transfn_oid and invtransfn_oid identify the funcs to be called; the + * latter may be InvalidOid, however if invtransfn_oid is set then + * transfn_oid must also be set. * * Pointers to the constructed trees are returned into *transfnexpr, - * *invtransfnexpr and *finalfnexpr. If there is no invtransfn or finalfn, - * the respective pointers are set to NULL. Since use of the invtransfn is - * optional, NULL may be passed for invtransfnexpr. + * *invtransfnexpr. If there is no invtransfn, the respective pointer is set + * to NULL. Since use of the invtransfn is optional, NULL may be passed for + * invtransfnexpr. */ void -build_aggregate_fnexprs(Oid *agg_input_types, - int agg_num_inputs, - int agg_num_direct_inputs, - int num_finalfn_inputs, - bool agg_variadic, - Oid agg_state_type, - Oid agg_result_type, - Oid agg_input_collation, - Oid transfn_oid, - Oid invtransfn_oid, - Oid finalfn_oid, - Expr **transfnexpr, - Expr **invtransfnexpr, - Expr **finalfnexpr) +build_aggregate_transfn_expr(Oid *agg_input_types, + int agg_num_inputs, + int agg_num_direct_inputs, + bool agg_variadic, + Oid agg_state_type, + Oid agg_input_collation, + Oid transfn_oid, + Oid invtransfn_oid, + Expr **transfnexpr, + Expr **invtransfnexpr) { Param *argp; List *args; @@ -1929,13 +1925,24 @@ build_aggregate_fnexprs(Oid *agg_input_types, else *invtransfnexpr = NULL; } +} - /* see if we have a final function */ - if (!OidIsValid(finalfn_oid)) - { - *finalfnexpr = NULL; - return; - } +/* + * Like build_aggregate_transfn_expr, but creates an expression tree for the + * final function of an aggregate, rather than the transition function. + */ +void +build_aggregate_finalfn_expr(Oid *agg_input_types, + int num_finalfn_inputs, + Oid agg_state_type, + Oid agg_result_type, + Oid agg_input_collation, + Oid finalfn_oid, + Expr **finalfnexpr) +{ + Param *argp; + List *args; + int i; /* * Build expr tree for final function diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h index 303fc3c1c7..5796de861c 100644 --- a/src/include/nodes/execnodes.h +++ b/src/include/nodes/execnodes.h @@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState typedef struct AggrefExprState { ExprState xprstate; - List *aggdirectargs; /* states of direct-argument expressions */ - List *args; /* states of aggregated-argument expressions */ - ExprState *aggfilter; /* state of FILTER expression, if any */ int aggno; /* ID number for agg within its plan node */ } AggrefExprState; @@ -1825,6 +1822,7 @@ typedef struct GroupState */ /* these structs are private in nodeAgg.c: */ typedef struct AggStatePerAggData *AggStatePerAgg; +typedef struct AggStatePerTransData *AggStatePerTrans; typedef struct AggStatePerGroupData *AggStatePerGroup; typedef struct AggStatePerPhaseData *AggStatePerPhase; @@ -1833,14 +1831,16 @@ typedef struct AggState ScanState ss; /* its first field is NodeTag */ List *aggs; /* all Aggref nodes in targetlist & quals */ int numaggs; /* length of list (could be zero!) */ + int numtrans; /* number of pertrans items */ AggStatePerPhase phase; /* pointer to current phase data */ int numphases; /* number of phases */ int current_phase; /* current phase number */ FmgrInfo *hashfunctions; /* per-grouping-field hash fns */ AggStatePerAgg peragg; /* per-Aggref information */ + AggStatePerTrans pertrans; /* per-Trans state information */ ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */ ExprContext *tmpcontext; /* econtext for input expressions */ - AggStatePerAgg curperagg; /* identifies currently active aggregate */ + AggStatePerTrans curpertrans; /* currently active trans state */ bool input_done; /* indicates end of input */ bool agg_done; /* indicates completion of Agg scan */ int projected_set; /* The last projected grouping set */ diff --git a/src/include/parser/parse_agg.h b/src/include/parser/parse_agg.h index 6a5f9bbdf1..e2b3894c28 100644 --- a/src/include/parser/parse_agg.h +++ b/src/include/parser/parse_agg.h @@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid, Oid *inputTypes, int numArguments); -extern void build_aggregate_fnexprs(Oid *agg_input_types, +extern void build_aggregate_transfn_expr(Oid *agg_input_types, int agg_num_inputs, int agg_num_direct_inputs, - int num_finalfn_inputs, bool agg_variadic, Oid agg_state_type, - Oid agg_result_type, Oid agg_input_collation, Oid transfn_oid, Oid invtransfn_oid, - Oid finalfn_oid, Expr **transfnexpr, - Expr **invtransfnexpr, + Expr **invtransfnexpr); + +extern void build_aggregate_finalfn_expr(Oid *agg_input_types, + int num_finalfn_inputs, + Oid agg_state_type, + Oid agg_result_type, + Oid agg_input_collation, + Oid finalfn_oid, Expr **finalfnexpr); #endif /* PARSE_AGG_H */ diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index 8852051e93..de826b5e50 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -1580,3 +1580,207 @@ select least_agg(variadic array[q1,q2]) from int8_tbl; -4567890123456789 (1 row) +-- test aggregates with common transition functions share the same states +begin work; +create type avg_state as (total bigint, count bigint); +create or replace function avg_transfn(state avg_state, n int) returns avg_state as +$$ +declare new_state avg_state; +begin + raise notice 'avg_transfn called with %', n; + if state is null then + if n is not null then + new_state.total := n; + new_state.count := 1; + return new_state; + end if; + return null; + elsif n is not null then + state.total := state.total + n; + state.count := state.count + 1; + return state; + end if; + + return null; +end +$$ language plpgsql; +create function avg_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total / state.count; + end if; +end +$$ language plpgsql; +create function sum_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total; + end if; +end +$$ language plpgsql; +create aggregate my_avg(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn +); +create aggregate my_sum(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn +); +-- aggregate state should be shared as aggs are the same. +select my_avg(one),my_avg(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_avg +--------+-------- + 2 | 2 +(1 row) + +-- aggregate state should be shared as transfn is the same for both aggs. +select my_avg(one),my_sum(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 4 +(1 row) + +-- shouldn't share states due to the distinctness not matching. +select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 4 +(1 row) + +-- shouldn't share states due to the filter clause not matching. +select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 3 | 4 +(1 row) + +-- this should not share the state due to different input columns. +select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two); +NOTICE: avg_transfn called with 2 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 4 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 6 +(1 row) + +-- test that aggs with the same sfunc and initcond share the same agg state +create aggregate my_sum_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn, + initcond = '(10,0)' +); +create aggregate my_avg_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(10,0)' +); +create aggregate my_avg_init2(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(4,0)' +); +-- state should be shared if INITCONDs are matching +select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_sum_init | my_avg_init +-------------+------------- + 14 | 7 +(1 row) + +-- Varying INITCONDs should cause the states not to be shared. +select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 3 + my_sum_init | my_avg_init2 +-------------+-------------- + 14 | 4 +(1 row) + +rollback; +-- test aggregate state sharing to ensure it works if one aggregate has a +-- finalfn and the other one has none. +begin work; +create or replace function sum_transfn(state int4, n int4) returns int4 as +$$ +declare new_state int4; +begin + raise notice 'sum_transfn called with %', n; + if state is null then + if n is not null then + new_state := n; + return new_state; + end if; + return null; + elsif n is not null then + state := state + n; + return state; + end if; + + return null; +end +$$ language plpgsql; +create function halfsum_finalfn(state int4) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state / 2; + end if; +end +$$ language plpgsql; +create aggregate my_sum(int4) +( + stype = int4, + sfunc = sum_transfn +); +create aggregate my_half_sum(int4) +( + stype = int4, + sfunc = sum_transfn, + finalfunc = halfsum_finalfn +); +-- Agg state should be shared even though my_sum has no finalfn +select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one); +NOTICE: sum_transfn called with 1 +NOTICE: sum_transfn called with 2 +NOTICE: sum_transfn called with 3 +NOTICE: sum_transfn called with 4 + my_sum | my_half_sum +--------+------------- + 10 | 5 +(1 row) + +rollback; diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index a84327d24c..8d501dc008 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -590,3 +590,168 @@ drop view aggordview1; -- variadic aggregates select least_agg(q1,q2) from int8_tbl; select least_agg(variadic array[q1,q2]) from int8_tbl; + + +-- test aggregates with common transition functions share the same states +begin work; + +create type avg_state as (total bigint, count bigint); + +create or replace function avg_transfn(state avg_state, n int) returns avg_state as +$$ +declare new_state avg_state; +begin + raise notice 'avg_transfn called with %', n; + if state is null then + if n is not null then + new_state.total := n; + new_state.count := 1; + return new_state; + end if; + return null; + elsif n is not null then + state.total := state.total + n; + state.count := state.count + 1; + return state; + end if; + + return null; +end +$$ language plpgsql; + +create function avg_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total / state.count; + end if; +end +$$ language plpgsql; + +create function sum_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total; + end if; +end +$$ language plpgsql; + +create aggregate my_avg(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn +); + +create aggregate my_sum(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn +); + +-- aggregate state should be shared as aggs are the same. +select my_avg(one),my_avg(one) from (values(1),(3)) t(one); + +-- aggregate state should be shared as transfn is the same for both aggs. +select my_avg(one),my_sum(one) from (values(1),(3)) t(one); + +-- shouldn't share states due to the distinctness not matching. +select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one); + +-- shouldn't share states due to the filter clause not matching. +select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one); + +-- this should not share the state due to different input columns. +select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two); + +-- test that aggs with the same sfunc and initcond share the same agg state +create aggregate my_sum_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn, + initcond = '(10,0)' +); + +create aggregate my_avg_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(10,0)' +); + +create aggregate my_avg_init2(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(4,0)' +); + +-- state should be shared if INITCONDs are matching +select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one); + +-- Varying INITCONDs should cause the states not to be shared. +select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one); + +rollback; + +-- test aggregate state sharing to ensure it works if one aggregate has a +-- finalfn and the other one has none. +begin work; + +create or replace function sum_transfn(state int4, n int4) returns int4 as +$$ +declare new_state int4; +begin + raise notice 'sum_transfn called with %', n; + if state is null then + if n is not null then + new_state := n; + return new_state; + end if; + return null; + elsif n is not null then + state := state + n; + return state; + end if; + + return null; +end +$$ language plpgsql; + +create function halfsum_finalfn(state int4) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state / 2; + end if; +end +$$ language plpgsql; + +create aggregate my_sum(int4) +( + stype = int4, + sfunc = sum_transfn +); + +create aggregate my_half_sum(int4) +( + stype = int4, + sfunc = sum_transfn, + finalfunc = halfsum_finalfn +); + +-- Agg state should be shared even though my_sum has no finalfn +select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one); + +rollback;