diff --git a/src/backend/executor/execQual.c b/src/backend/executor/execQual.c index 16bc8fa5f6..29f058ce5c 100644 --- a/src/backend/executor/execQual.c +++ b/src/backend/executor/execQual.c @@ -4487,35 +4487,15 @@ ExecInitExpr(Expr *node, PlanState *parent) break; case T_Aggref: { - Aggref *aggref = (Aggref *) node; AggrefExprState *astate = makeNode(AggrefExprState); astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref; if (parent && IsA(parent, AggState)) { AggState *aggstate = (AggState *) parent; - int naggs; aggstate->aggs = lcons(astate, aggstate->aggs); - naggs = ++aggstate->numaggs; - - astate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs, - parent); - astate->args = (List *) ExecInitExpr((Expr *) aggref->args, - parent); - astate->aggfilter = ExecInitExpr(aggref->aggfilter, - parent); - - /* - * Complain if the aggregate's arguments contain any - * aggregates; nested agg functions are semantically - * nonsensical. (This should have been caught earlier, - * but we defend against it here anyway.) - */ - if (naggs != aggstate->numaggs) - ereport(ERROR, - (errcode(ERRCODE_GROUPING_ERROR), - errmsg("aggregate function calls cannot be nested"))); + aggstate->numaggs++; } else { diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 2bf48c54e3..2e3685557b 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -152,17 +152,28 @@ /* - * AggStatePerAggData - per-aggregate working state for the Agg scan + * AggStatePerTransData - per aggregate state value information + * + * Working state for updating the aggregate's state value, by calling the + * transition function with an input row. This struct does not store the + * information needed to produce the final aggregate result from the transition + * state, that's stored in AggStatePerAggData instead. This separation allows + * multiple aggregate results to be produced from a single state value. 
*/ -typedef struct AggStatePerAggData +typedef struct AggStatePerTransData { /* * These values are set up during ExecInitAgg() and do not change * thereafter: */ - /* Links to Aggref expr and state nodes this working state is for */ - AggrefExprState *aggrefstate; + /* + * Link to an Aggref expr this state value is for. + * + * There can be multiple Aggref's sharing the same state value, as long as + * the inputs and transition function are identical. This points to the + * first one of them. + */ Aggref *aggref; /* @@ -186,25 +197,22 @@ typedef struct AggStatePerAggData */ int numTransInputs; - /* - * Number of arguments to pass to the finalfn. This is always at least 1 - * (the transition state value) plus any ordered-set direct args. If the - * finalfn wants extra args then we pass nulls corresponding to the - * aggregated input columns. - */ - int numFinalArgs; - - /* Oids of transfer functions */ + /* Oid of the state transition function */ Oid transfn_oid; - Oid finalfn_oid; /* may be InvalidOid */ + + /* Oid of state value's datatype */ + Oid aggtranstype; + + /* ExprStates of the FILTER and argument expressions. */ + ExprState *aggfilter; /* state of FILTER expression, if any */ + List *args; /* states of aggregated-argument expressions */ + List *aggdirectargs; /* states of direct-argument expressions */ /* - * fmgr lookup data for transfer functions --- only valid when - * corresponding oid is not InvalidOid. Note in particular that fn_strict - * flags are kept here. + * fmgr lookup data for transition function. Note in particular that the + * fn_strict flag is kept here. */ FmgrInfo transfn; - FmgrInfo finalfn; /* Input collation derived for aggregate */ Oid aggCollation; @@ -236,17 +244,15 @@ typedef struct AggStatePerAggData bool initValueIsNull; /* - * We need the len and byval info for the agg's input, result, and - * transition data types in order to know how to copy/delete values. 
+ * We need the len and byval info for the agg's input and transition data + * types in order to know how to copy/delete values. * * Note that the info for the input type is used only when handling * DISTINCT aggs with just one argument, so there is only one input type. */ int16 inputtypeLen, - resulttypeLen, transtypeLen; bool inputtypeByVal, - resulttypeByVal, transtypeByVal; /* @@ -288,6 +294,54 @@ typedef struct AggStatePerAggData * worth the extra space consumption. */ FunctionCallInfoData transfn_fcinfo; +} AggStatePerTransData; + +/* + * AggStatePerAggData - per-aggregate information + * + * This contains the information needed to call the final function, to produce + * a final aggregate result from the state value. If there are multiple + * identical Aggrefs in the query, they can all share the same per-agg data. + * + * These values are set up during ExecInitAgg() and do not change thereafter. + */ +typedef struct AggStatePerAggData +{ + /* + * Link to an Aggref expr this per-agg data is for. + * + * There can be multiple identical Aggref's sharing the same per-agg. This + * points to the first one of them. + */ + Aggref *aggref; + + /* index to the state value which this agg should use */ + int transno; + + /* Optional Oid of final function (may be InvalidOid) */ + Oid finalfn_oid; + + /* + * fmgr lookup data for final function --- only valid when finalfn_oid + * is not InvalidOid. + */ + FmgrInfo finalfn; + + /* + * Number of arguments to pass to the finalfn. This is always at least 1 + * (the transition state value) plus any ordered-set direct args. If the + * finalfn wants extra args then we pass nulls corresponding to the + * aggregated input columns. + */ + int numFinalArgs; + + /* + * We need the len and byval info for the agg's result data type in order + * to know how to copy/delete values.
+ */ + int16 resulttypeLen; + bool resulttypeByVal; + } AggStatePerAggData; /* @@ -358,25 +412,23 @@ typedef struct AggHashEntryData AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER]; } AggHashEntryData; - static void initialize_phase(AggState *aggstate, int newphase); static TupleTableSlot *fetch_input_tuple(AggState *aggstate); static void initialize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, AggStatePerGroup pergroup, int numReset); static void advance_transition_function(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup); static void process_ordered_aggregate_single(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void process_ordered_aggregate_multi(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); static void finalize_aggregate(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull); static void prepare_projection_slot(AggState *aggstate, @@ -396,6 +448,17 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate); static void agg_fill_hash_table(AggState *aggstate); static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate); static Datum GetAggInitVal(Datum textInitVal, Oid transtype); +static void build_pertrans_for_aggref(AggStatePerTrans pertrans, + AggState *aggsate, EState *estate, + Aggref *aggref, Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + Oid *inputTypes, int numArguments); +static int find_compatible_peragg(Aggref *newagg, AggState *aggstate, + int lastaggno, List **same_input_transnos); +static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg, + Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool 
initValueIsNull, + List *possible_matches); /* @@ -498,20 +561,20 @@ fetch_input_tuple(AggState *aggstate) * When called, CurrentMemoryContext should be the per-query context. */ static void -initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, +initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { /* * Start a fresh sort operation for each DISTINCT/ORDER BY aggregate. */ - if (peraggstate->numSortCols > 0) + if (pertrans->numSortCols > 0) { /* * In case of rescan, maybe there could be an uncompleted sort * operation? Clean it up if so. */ - if (peraggstate->sortstates[aggstate->current_set]) - tuplesort_end(peraggstate->sortstates[aggstate->current_set]); + if (pertrans->sortstates[aggstate->current_set]) + tuplesort_end(pertrans->sortstates[aggstate->current_set]); /* @@ -519,21 +582,21 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, * otherwise sort the full tuple. (See comments for * process_ordered_aggregate_single.) 
*/ - if (peraggstate->numInputs == 1) - peraggstate->sortstates[aggstate->current_set] = - tuplesort_begin_datum(peraggstate->evaldesc->attrs[0]->atttypid, - peraggstate->sortOperators[0], - peraggstate->sortCollations[0], - peraggstate->sortNullsFirst[0], + if (pertrans->numInputs == 1) + pertrans->sortstates[aggstate->current_set] = + tuplesort_begin_datum(pertrans->evaldesc->attrs[0]->atttypid, + pertrans->sortOperators[0], + pertrans->sortCollations[0], + pertrans->sortNullsFirst[0], work_mem, false); else - peraggstate->sortstates[aggstate->current_set] = - tuplesort_begin_heap(peraggstate->evaldesc, - peraggstate->numSortCols, - peraggstate->sortColIdx, - peraggstate->sortOperators, - peraggstate->sortCollations, - peraggstate->sortNullsFirst, + pertrans->sortstates[aggstate->current_set] = + tuplesort_begin_heap(pertrans->evaldesc, + pertrans->numSortCols, + pertrans->sortColIdx, + pertrans->sortOperators, + pertrans->sortCollations, + pertrans->sortNullsFirst, work_mem, false); } @@ -543,20 +606,20 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, * Note that when the initial value is pass-by-ref, we must copy it (into * the aggcontext) since we will pfree the transValue later. 
*/ - if (peraggstate->initValueIsNull) - pergroupstate->transValue = peraggstate->initValue; + if (pertrans->initValueIsNull) + pergroupstate->transValue = pertrans->initValue; else { MemoryContext oldContext; oldContext = MemoryContextSwitchTo( aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); - pergroupstate->transValue = datumCopy(peraggstate->initValue, - peraggstate->transtypeByVal, - peraggstate->transtypeLen); + pergroupstate->transValue = datumCopy(pertrans->initValue, + pertrans->transtypeByVal, + pertrans->transtypeLen); MemoryContextSwitchTo(oldContext); } - pergroupstate->transValueIsNull = peraggstate->initValueIsNull; + pergroupstate->transValueIsNull = pertrans->initValueIsNull; /* * If the initial value for the transition state doesn't exist in the @@ -565,11 +628,11 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, * aggregates like max() and min().) The noTransValue flag signals that we * still need to do this. */ - pergroupstate->noTransValue = peraggstate->initValueIsNull; + pergroupstate->noTransValue = pertrans->initValueIsNull; } /* - * Initialize all aggregates for a new group of input values. + * Initialize all aggregate transition states for a new group of input values. 
* * If there are multiple grouping sets, we initialize only the first numReset * of them (the grouping sets are ordered so that the most specific one, which @@ -580,61 +643,61 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, */ static void initialize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, AggStatePerGroup pergroup, int numReset) { - int aggno; + int transno; int numGroupingSets = Max(aggstate->phase->numsets, 1); int setno = 0; + AggStatePerTrans transstates = aggstate->pertrans; if (numReset < 1) numReset = numGroupingSets; - for (aggno = 0; aggno < aggstate->numaggs; aggno++) + for (transno = 0; transno < aggstate->numtrans; transno++) { - AggStatePerAgg peraggstate = &peragg[aggno]; + AggStatePerTrans pertrans = &transstates[transno]; for (setno = 0; setno < numReset; setno++) { AggStatePerGroup pergroupstate; - pergroupstate = &pergroup[aggno + (setno * (aggstate->numaggs))]; + pergroupstate = &pergroup[transno + (setno * (aggstate->numtrans))]; aggstate->current_set = setno; - initialize_aggregate(aggstate, peraggstate, pergroupstate); + initialize_aggregate(aggstate, pertrans, pergroupstate); } } } /* * Given new input value(s), advance the transition function of one aggregate - * within one grouping set only (already set in aggstate->current_set) + * state within one grouping set only (already set in aggstate->current_set) * * The new values (and null flags) have been preloaded into argument positions - * 1 and up in peraggstate->transfn_fcinfo, so that we needn't copy them again - * to pass to the transition function. We also expect that the static fields - * of the fcinfo are already initialized; that was done by ExecInitAgg(). + * 1 and up in pertrans->transfn_fcinfo, so that we needn't copy them again to + * pass to the transition function. We also expect that the static fields of + * the fcinfo are already initialized; that was done by ExecInitAgg(). * * It doesn't matter which memory context this is called in. 
*/ static void advance_transition_function(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; MemoryContext oldContext; Datum newVal; - if (peraggstate->transfn.fn_strict) + if (pertrans->transfn.fn_strict) { /* * For a strict transfn, nothing happens when there's a NULL input; we * just keep the prior transValue. */ - int numTransInputs = peraggstate->numTransInputs; + int numTransInputs = pertrans->numTransInputs; int i; for (i = 1; i <= numTransInputs; i++) @@ -656,8 +719,8 @@ advance_transition_function(AggState *aggstate, oldContext = MemoryContextSwitchTo( aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); pergroupstate->transValue = datumCopy(fcinfo->arg[1], - peraggstate->transtypeByVal, - peraggstate->transtypeLen); + pertrans->transtypeByVal, + pertrans->transtypeLen); pergroupstate->transValueIsNull = false; pergroupstate->noTransValue = false; MemoryContextSwitchTo(oldContext); @@ -678,8 +741,8 @@ advance_transition_function(AggState *aggstate, /* We run the transition functions in per-input-tuple memory context */ oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); - /* set up aggstate->curperagg for AggGetAggref() */ - aggstate->curperagg = peraggstate; + /* set up aggstate->curpertrans for AggGetAggref() */ + aggstate->curpertrans = pertrans; /* * OK to call the transition function @@ -690,22 +753,22 @@ advance_transition_function(AggState *aggstate, newVal = FunctionCallInvoke(fcinfo); - aggstate->curperagg = NULL; + aggstate->curpertrans = NULL; /* * If pass-by-ref datatype, must copy the new value into aggcontext and * pfree the prior transValue. But if transfn returned a pointer to its * first input, we don't need to do anything. 
*/ - if (!peraggstate->transtypeByVal && + if (!pertrans->transtypeByVal && DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue)) { if (!fcinfo->isnull) { MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); newVal = datumCopy(newVal, - peraggstate->transtypeByVal, - peraggstate->transtypeLen); + pertrans->transtypeByVal, + pertrans->transtypeLen); } if (!pergroupstate->transValueIsNull) pfree(DatumGetPointer(pergroupstate->transValue)); @@ -718,26 +781,26 @@ advance_transition_function(AggState *aggstate, } /* - * Advance all the aggregates for one input tuple. The input tuple - * has been stored in tmpcontext->ecxt_outertuple, so that it is accessible - * to ExecEvalExpr. pergroup is the array of per-group structs to use - * (this might be in a hashtable entry). + * Advance each aggregate transition state for one input tuple. The input + * tuple has been stored in tmpcontext->ecxt_outertuple, so that it is + * accessible to ExecEvalExpr. pergroup is the array of per-group structs to + * use (this might be in a hashtable entry). * * When called, CurrentMemoryContext should be the per-query context. 
*/ static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) { - int aggno; + int transno; int setno = 0; int numGroupingSets = Max(aggstate->phase->numsets, 1); - int numAggs = aggstate->numaggs; + int numTrans = aggstate->numtrans; - for (aggno = 0; aggno < numAggs; aggno++) + for (transno = 0; transno < numTrans; transno++) { - AggStatePerAgg peraggstate = &aggstate->peragg[aggno]; - ExprState *filter = peraggstate->aggrefstate->aggfilter; - int numTransInputs = peraggstate->numTransInputs; + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + ExprState *filter = pertrans->aggfilter; + int numTransInputs = pertrans->numTransInputs; int i; TupleTableSlot *slot; @@ -754,12 +817,12 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) } /* Evaluate the current input expressions for this aggregate */ - slot = ExecProject(peraggstate->evalproj, NULL); + slot = ExecProject(pertrans->evalproj, NULL); - if (peraggstate->numSortCols > 0) + if (pertrans->numSortCols > 0) { /* DISTINCT and/or ORDER BY case */ - Assert(slot->tts_nvalid == peraggstate->numInputs); + Assert(slot->tts_nvalid == pertrans->numInputs); /* * If the transfn is strict, we want to check for nullity before @@ -768,7 +831,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) * not numInputs, since nullity in columns used only for sorting * is not relevant here. 
*/ - if (peraggstate->transfn.fn_strict) + if (pertrans->transfn.fn_strict) { for (i = 0; i < numTransInputs; i++) { @@ -782,18 +845,18 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) for (setno = 0; setno < numGroupingSets; setno++) { /* OK, put the tuple into the tuplesort object */ - if (peraggstate->numInputs == 1) - tuplesort_putdatum(peraggstate->sortstates[setno], + if (pertrans->numInputs == 1) + tuplesort_putdatum(pertrans->sortstates[setno], slot->tts_values[0], slot->tts_isnull[0]); else - tuplesort_puttupleslot(peraggstate->sortstates[setno], slot); + tuplesort_puttupleslot(pertrans->sortstates[setno], slot); } } else { /* We can apply the transition function immediately */ - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; /* Load values into fcinfo */ /* Start from 1, since the 0th arg will be the transition value */ @@ -806,11 +869,11 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) for (setno = 0; setno < numGroupingSets; setno++) { - AggStatePerGroup pergroupstate = &pergroup[aggno + (setno * numAggs)]; + AggStatePerGroup pergroupstate = &pergroup[transno + (setno * numTrans)]; aggstate->current_set = setno; - advance_transition_function(aggstate, peraggstate, pergroupstate); + advance_transition_function(aggstate, pertrans, pergroupstate); } } } @@ -841,7 +904,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) */ static void process_ordered_aggregate_single(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { Datum oldVal = (Datum) 0; @@ -849,14 +912,14 @@ process_ordered_aggregate_single(AggState *aggstate, bool haveOldVal = false; MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; MemoryContext oldContext; - bool isDistinct = (peraggstate->numDistinctCols > 0); - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; + bool isDistinct = 
(pertrans->numDistinctCols > 0); + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; Datum *newVal; bool *isNull; - Assert(peraggstate->numDistinctCols < 2); + Assert(pertrans->numDistinctCols < 2); - tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]); + tuplesort_performsort(pertrans->sortstates[aggstate->current_set]); /* Load the column into argument 1 (arg 0 will be transition value) */ newVal = fcinfo->arg + 1; @@ -868,7 +931,7 @@ process_ordered_aggregate_single(AggState *aggstate, * pfree them when they are no longer needed. */ - while (tuplesort_getdatum(peraggstate->sortstates[aggstate->current_set], + while (tuplesort_getdatum(pertrans->sortstates[aggstate->current_set], true, newVal, isNull)) { /* @@ -887,18 +950,18 @@ process_ordered_aggregate_single(AggState *aggstate, haveOldVal && ((oldIsNull && *isNull) || (!oldIsNull && !*isNull && - DatumGetBool(FunctionCall2(&peraggstate->equalfns[0], + DatumGetBool(FunctionCall2(&pertrans->equalfns[0], oldVal, *newVal))))) { /* equal to prior, so forget this one */ - if (!peraggstate->inputtypeByVal && !*isNull) + if (!pertrans->inputtypeByVal && !*isNull) pfree(DatumGetPointer(*newVal)); } else { - advance_transition_function(aggstate, peraggstate, pergroupstate); + advance_transition_function(aggstate, pertrans, pergroupstate); /* forget the old value, if any */ - if (!oldIsNull && !peraggstate->inputtypeByVal) + if (!oldIsNull && !pertrans->inputtypeByVal) pfree(DatumGetPointer(oldVal)); /* and remember the new one for subsequent equality checks */ oldVal = *newVal; @@ -909,11 +972,11 @@ process_ordered_aggregate_single(AggState *aggstate, MemoryContextSwitchTo(oldContext); } - if (!oldIsNull && !peraggstate->inputtypeByVal) + if (!oldIsNull && !pertrans->inputtypeByVal) pfree(DatumGetPointer(oldVal)); - tuplesort_end(peraggstate->sortstates[aggstate->current_set]); - peraggstate->sortstates[aggstate->current_set] = NULL; + tuplesort_end(pertrans->sortstates[aggstate->current_set]); + 
pertrans->sortstates[aggstate->current_set] = NULL; } /* @@ -930,25 +993,25 @@ process_ordered_aggregate_single(AggState *aggstate, */ static void process_ordered_aggregate_multi(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerTrans pertrans, AggStatePerGroup pergroupstate) { MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; - FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; - TupleTableSlot *slot1 = peraggstate->evalslot; - TupleTableSlot *slot2 = peraggstate->uniqslot; - int numTransInputs = peraggstate->numTransInputs; - int numDistinctCols = peraggstate->numDistinctCols; + FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo; + TupleTableSlot *slot1 = pertrans->evalslot; + TupleTableSlot *slot2 = pertrans->uniqslot; + int numTransInputs = pertrans->numTransInputs; + int numDistinctCols = pertrans->numDistinctCols; bool haveOldValue = false; int i; - tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]); + tuplesort_performsort(pertrans->sortstates[aggstate->current_set]); ExecClearTuple(slot1); if (slot2) ExecClearTuple(slot2); - while (tuplesort_gettupleslot(peraggstate->sortstates[aggstate->current_set], + while (tuplesort_gettupleslot(pertrans->sortstates[aggstate->current_set], true, slot1)) { /* @@ -962,8 +1025,8 @@ process_ordered_aggregate_multi(AggState *aggstate, !haveOldValue || !execTuplesMatch(slot1, slot2, numDistinctCols, - peraggstate->sortColIdx, - peraggstate->equalfns, + pertrans->sortColIdx, + pertrans->equalfns, workcontext)) { /* Load values into fcinfo */ @@ -974,7 +1037,7 @@ process_ordered_aggregate_multi(AggState *aggstate, fcinfo->argnull[i + 1] = slot1->tts_isnull[i]; } - advance_transition_function(aggstate, peraggstate, pergroupstate); + advance_transition_function(aggstate, pertrans, pergroupstate); if (numDistinctCols > 0) { @@ -997,8 +1060,8 @@ process_ordered_aggregate_multi(AggState *aggstate, if (slot2) ExecClearTuple(slot2); - 
tuplesort_end(peraggstate->sortstates[aggstate->current_set]); - peraggstate->sortstates[aggstate->current_set] = NULL; + tuplesort_end(pertrans->sortstates[aggstate->current_set]); + pertrans->sortstates[aggstate->current_set] = NULL; } /* @@ -1009,10 +1072,14 @@ process_ordered_aggregate_multi(AggState *aggstate, * * The finalfunction will be run, and the result delivered, in the * output-tuple context; caller's CurrentMemoryContext does not matter. + * + * The finalfn uses the state as set in the transno. This also might be + * being used by another aggregate function, so it's important that we do + * nothing destructive here. */ static void finalize_aggregate(AggState *aggstate, - AggStatePerAgg peraggstate, + AggStatePerAgg peragg, AggStatePerGroup pergroupstate, Datum *resultVal, bool *resultIsNull) { @@ -1021,6 +1088,7 @@ finalize_aggregate(AggState *aggstate, MemoryContext oldContext; int i; ListCell *lc; + AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno]; oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); @@ -1031,7 +1099,7 @@ finalize_aggregate(AggState *aggstate, * for the transition state value. */ i = 1; - foreach(lc, peraggstate->aggrefstate->aggdirectargs) + foreach(lc, pertrans->aggdirectargs) { ExprState *expr = (ExprState *) lfirst(lc); @@ -1046,16 +1114,16 @@ finalize_aggregate(AggState *aggstate, /* * Apply the agg's finalfn if one is provided, else return transValue. 
*/ - if (OidIsValid(peraggstate->finalfn_oid)) + if (OidIsValid(peragg->finalfn_oid)) { - int numFinalArgs = peraggstate->numFinalArgs; + int numFinalArgs = peragg->numFinalArgs; - /* set up aggstate->curperagg for AggGetAggref() */ - aggstate->curperagg = peraggstate; + /* set up aggstate->curpertrans for AggGetAggref() */ + aggstate->curpertrans = pertrans; - InitFunctionCallInfoData(fcinfo, &peraggstate->finalfn, + InitFunctionCallInfoData(fcinfo, &peragg->finalfn, numFinalArgs, - peraggstate->aggCollation, + pertrans->aggCollation, (void *) aggstate, NULL); /* Fill in the transition state value */ @@ -1082,7 +1150,7 @@ finalize_aggregate(AggState *aggstate, *resultVal = FunctionCallInvoke(&fcinfo); *resultIsNull = fcinfo.isnull; } - aggstate->curperagg = NULL; + aggstate->curpertrans = NULL; } else { @@ -1093,12 +1161,12 @@ finalize_aggregate(AggState *aggstate, /* * If result is pass-by-ref, make sure it is in the right context. */ - if (!peraggstate->resulttypeByVal && !*resultIsNull && + if (!peragg->resulttypeByVal && !*resultIsNull && !MemoryContextContains(CurrentMemoryContext, DatumGetPointer(*resultVal))) *resultVal = datumCopy(*resultVal, - peraggstate->resulttypeByVal, - peraggstate->resulttypeLen); + peragg->resulttypeByVal, + peragg->resulttypeLen); MemoryContextSwitchTo(oldContext); } @@ -1173,7 +1241,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet */ static void finalize_aggregates(AggState *aggstate, - AggStatePerAgg peragg, + AggStatePerAgg peraggs, AggStatePerGroup pergroup, int currentSet) { @@ -1189,26 +1257,28 @@ finalize_aggregates(AggState *aggstate, for (aggno = 0; aggno < aggstate->numaggs; aggno++) { - AggStatePerAgg peraggstate = &peragg[aggno]; + AggStatePerAgg peragg = &peraggs[aggno]; + int transno = peragg->transno; + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; AggStatePerGroup pergroupstate; - pergroupstate = &pergroup[aggno + (currentSet * (aggstate->numaggs))]; + pergroupstate 
= &pergroup[transno + (currentSet * (aggstate->numtrans))]; - if (peraggstate->numSortCols > 0) + if (pertrans->numSortCols > 0) { Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED); - if (peraggstate->numInputs == 1) + if (pertrans->numInputs == 1) process_ordered_aggregate_single(aggstate, - peraggstate, + pertrans, pergroupstate); else process_ordered_aggregate_multi(aggstate, - peraggstate, + pertrans, pergroupstate); } - finalize_aggregate(aggstate, peraggstate, pergroupstate, + finalize_aggregate(aggstate, peragg, pergroupstate, &aggvalues[aggno], &aggnulls[aggno]); } } @@ -1428,7 +1498,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot) if (isnew) { /* initialize aggregates for new tuple group */ - initialize_aggregates(aggstate, aggstate->peragg, entry->pergroup, 0); + initialize_aggregates(aggstate, entry->pergroup, 0); } return entry; @@ -1716,7 +1786,7 @@ agg_retrieve_direct(AggState *aggstate) /* * Initialize working state for a new input tuple group. 
*/ - initialize_aggregates(aggstate, peragg, pergroup, numReset); + initialize_aggregates(aggstate, pergroup, numReset); if (aggstate->grp_firstTuple != NULL) { @@ -1945,17 +2015,18 @@ AggState * ExecInitAgg(Agg *node, EState *estate, int eflags) { AggState *aggstate; - AggStatePerAgg peragg; + AggStatePerAgg peraggs; + AggStatePerTrans pertransstates; Plan *outerPlan; ExprContext *econtext; int numaggs, + transno, aggno; int phase; ListCell *l; Bitmapset *all_grouped_cols = NULL; int numGroupingSets = 1; int numPhases; - int currentsortno = 0; int i = 0; int j = 0; @@ -1971,12 +2042,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->aggs = NIL; aggstate->numaggs = 0; + aggstate->numtrans = 0; aggstate->maxsets = 0; aggstate->hashfunctions = NULL; aggstate->projected_set = -1; aggstate->current_set = 0; aggstate->peragg = NULL; - aggstate->curperagg = NULL; + aggstate->pertrans = NULL; + aggstate->curpertrans = NULL; aggstate->agg_done = false; aggstate->input_done = false; aggstate->pergroup = NULL; @@ -2209,8 +2282,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs); econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs); - peragg = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs); - aggstate->peragg = peragg; + peraggs = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs); + pertransstates = (AggStatePerTrans) palloc0(sizeof(AggStatePerTransData) * numaggs); + + aggstate->peragg = peraggs; + aggstate->pertrans = pertransstates; if (node->aggstrategy == AGG_HASHED) { @@ -2230,71 +2306,86 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->pergroup = pergroup; } - /* + /* ----------------- * Perform lookups of aggregate function info, and initialize the - * unchanging fields of the per-agg data. We also detect duplicate - * aggregates (for example, "SELECT sum(x) ... HAVING sum(x) > 0"). 
When - * duplicates are detected, we only make an AggStatePerAgg struct for the - * first one. The clones are simply pointed at the same result entry by - * giving them duplicate aggno values. + * unchanging fields of the per-agg and per-trans data. + * + * We try to optimize by detecting duplicate aggregate functions so that + * their state and final values are re-used, rather than needlessly being + * re-calculated independently. We also detect aggregates that are not + * the same, but which can share the same transition state. + * + * Scenarios: + * + * 1. An aggregate function appears more than once in query: + * + * SELECT SUM(x) FROM ... HAVING SUM(x) > 0 + * + * Since the aggregates are identical, we only need to calculate + * the result once. Both aggregates will share the same 'aggno' + * value. + * + * 2. Two different aggregate functions appear in the query, but the + * aggregates have the same transition function and initial value, but + * different final function: + * + * SELECT SUM(x), AVG(x) FROM ... + * + * In this case we must create a new peragg for the varying aggregate, + * and need to call the final functions separately, but can share the + * same transition state. + * + * For either of these optimizations to be valid, the aggregate's + * arguments must be the same, including any modifiers such as ORDER BY, + * DISTINCT and FILTER, and they mustn't contain any volatile functions.
+ * ----------------- */ aggno = -1; + transno = -1; foreach(l, aggstate->aggs) { AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l); Aggref *aggref = (Aggref *) aggrefstate->xprstate.expr; - AggStatePerAgg peraggstate; + AggStatePerAgg peragg; + AggStatePerTrans pertrans; + int existing_aggno; + int existing_transno; + List *same_input_transnos; Oid inputTypes[FUNC_MAX_ARGS]; int numArguments; int numDirectArgs; - int numInputs; - int numSortCols; - int numDistinctCols; - List *sortlist; HeapTuple aggTuple; Form_pg_aggregate aggform; - Oid aggtranstype; AclResult aclresult; Oid transfn_oid, finalfn_oid; - Expr *transfnexpr, - *finalfnexpr; + Expr *finalfnexpr; + Oid aggtranstype; Datum textInitVal; - int i; - ListCell *lc; + Datum initValue; + bool initValueIsNull; /* Planner should have assigned aggregate to correct level */ Assert(aggref->agglevelsup == 0); - /* Look for a previous duplicate aggregate */ - for (i = 0; i <= aggno; i++) + /* 1. Check for already processed aggs which can be re-used */ + existing_aggno = find_compatible_peragg(aggref, aggstate, aggno, + &same_input_transnos); + if (existing_aggno != -1) { - if (equal(aggref, peragg[i].aggref) && - !contain_volatile_functions((Node *) aggref)) - break; - } - if (i <= aggno) - { - /* Found a match to an existing entry, so just mark it */ - aggrefstate->aggno = i; + /* + * Existing compatible agg found, so just point the Aggref to the + * same per-agg struct.
+ */ + aggrefstate->aggno = existing_aggno; continue; } - /* Nope, so assign a new PerAgg record */ - peraggstate = &peragg[++aggno]; - /* Mark Aggref state node with assigned index in the result array */ + peragg = &peraggs[++aggno]; + peragg->aggref = aggref; aggrefstate->aggno = aggno; - /* Begin filling in the peraggstate data */ - peraggstate->aggrefstate = aggrefstate; - peraggstate->aggref = aggref; - peraggstate->sortstates = (Tuplesortstate **) - palloc0(sizeof(Tuplesortstate *) * numGroupingSets); - - for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++) - peraggstate->sortstates[currentsortno] = NULL; - /* Fetch the pg_aggregate row */ aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(aggref->aggfnoid)); @@ -2311,8 +2402,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) get_func_name(aggref->aggfnoid)); InvokeFunctionExecuteHook(aggref->aggfnoid); - peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn; - peraggstate->finalfn_oid = finalfn_oid = aggform->aggfinalfn; + transfn_oid = aggform->aggtransfn; + peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn; /* Check that aggregate owner has permission to call component fns */ { @@ -2350,74 +2441,43 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) * agg accepts ANY or a polymorphic type. 
*/ numArguments = get_aggregate_argtypes(aggref, inputTypes); - peraggstate->numArguments = numArguments; /* Count the "direct" arguments, if any */ numDirectArgs = list_length(aggref->aggdirectargs); - /* Count the number of aggregated input columns */ - numInputs = list_length(aggref->args); - peraggstate->numInputs = numInputs; - - /* Detect how many arguments to pass to the transfn */ - if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) - peraggstate->numTransInputs = numInputs; - else - peraggstate->numTransInputs = numArguments; - - /* Detect how many arguments to pass to the finalfn */ - if (aggform->aggfinalextra) - peraggstate->numFinalArgs = numArguments + 1; - else - peraggstate->numFinalArgs = numDirectArgs + 1; - /* resolve actual type of transition state, if polymorphic */ aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, aggform->aggtranstype, inputTypes, numArguments); - /* build expression trees using actual argument & result types */ - build_aggregate_fnexprs(inputTypes, - numArguments, - numDirectArgs, - peraggstate->numFinalArgs, - aggref->aggvariadic, - aggtranstype, - aggref->aggtype, - aggref->inputcollid, - transfn_oid, - InvalidOid, /* invtrans is not needed here */ - finalfn_oid, - &transfnexpr, - NULL, - &finalfnexpr); - - /* set up infrastructure for calling the transfn and finalfn */ - fmgr_info(transfn_oid, &peraggstate->transfn); - fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn); + /* Detect how many arguments to pass to the finalfn */ + if (aggform->aggfinalextra) + peragg->numFinalArgs = numArguments + 1; + else + peragg->numFinalArgs = numDirectArgs + 1; + /* + * build expression trees using actual argument & result types for the + * finalfn, if it exists + */ if (OidIsValid(finalfn_oid)) { - fmgr_info(finalfn_oid, &peraggstate->finalfn); - fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn); + build_aggregate_finalfn_expr(inputTypes, + peragg->numFinalArgs, + aggtranstype, + aggref->aggtype, + 
aggref->inputcollid, + finalfn_oid, + &finalfnexpr); + fmgr_info(finalfn_oid, &peragg->finalfn); + fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn); } - peraggstate->aggCollation = aggref->inputcollid; - - InitFunctionCallInfoData(peraggstate->transfn_fcinfo, - &peraggstate->transfn, - peraggstate->numTransInputs + 1, - peraggstate->aggCollation, - (void *) aggstate, NULL); - - /* get info about relevant datatypes */ + /* get info about the result type's datatype */ get_typlenbyval(aggref->aggtype, - &peraggstate->resulttypeLen, - &peraggstate->resulttypeByVal); - get_typlenbyval(aggtranstype, - &peraggstate->transtypeLen, - &peraggstate->transtypeByVal); + &peragg->resulttypeLen, + &peragg->resulttypeByVal); /* * initval is potentially null, so don't try to access it as a struct @@ -2425,161 +2485,292 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) */ textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple, Anum_pg_aggregate_agginitval, - &peraggstate->initValueIsNull); - - if (peraggstate->initValueIsNull) - peraggstate->initValue = (Datum) 0; + &initValueIsNull); + if (initValueIsNull) + initValue = (Datum) 0; else - peraggstate->initValue = GetAggInitVal(textInitVal, - aggtranstype); + initValue = GetAggInitVal(textInitVal, aggtranstype); /* - * If the transfn is strict and the initval is NULL, make sure input - * type and transtype are the same (or at least binary-compatible), so - * that it's OK to use the first aggregated input value as the initial - * transValue. This should have been checked at agg definition time, - * but we must check again in case the transfn's strictness property - * has been changed. 
- */ - if (peraggstate->transfn.fn_strict && peraggstate->initValueIsNull) - { - if (numArguments <= numDirectArgs || - !IsBinaryCoercible(inputTypes[numDirectArgs], aggtranstype)) - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("aggregate %u needs to have compatible input type and transition type", - aggref->aggfnoid))); - } - - /* - * Get a tupledesc corresponding to the aggregated inputs (including - * sort expressions) of the agg. - */ - peraggstate->evaldesc = ExecTypeFromTL(aggref->args, false); - - /* Create slot we're going to do argument evaluation in */ - peraggstate->evalslot = ExecInitExtraTupleSlot(estate); - ExecSetSlotDescriptor(peraggstate->evalslot, peraggstate->evaldesc); - - /* Set up projection info for evaluation */ - peraggstate->evalproj = ExecBuildProjectionInfo(aggrefstate->args, - aggstate->tmpcontext, - peraggstate->evalslot, - NULL); - - /* - * If we're doing either DISTINCT or ORDER BY for a plain agg, then we - * have a list of SortGroupClause nodes; fish out the data in them and - * stick them into arrays. We ignore ORDER BY for an ordered-set agg, - * however; the agg's transfn and finalfn are responsible for that. + * 2. Build working state for invoking the transition function, or + * look up previously initialized working state, if we can share it. * - * Note that by construction, if there is a DISTINCT clause then the - * ORDER BY clause is a prefix of it (see transformDistinctClause). + * find_compatible_peragg() already collected a list of per-Trans's + * with the same inputs. Check if any of them have the same transition + * function and initial value. 
*/ - if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + existing_transno = find_compatible_pertrans(aggstate, aggref, + transfn_oid, aggtranstype, + initValue, initValueIsNull, + same_input_transnos); + if (existing_transno != -1) { - sortlist = NIL; - numSortCols = numDistinctCols = 0; - } - else if (aggref->aggdistinct) - { - sortlist = aggref->aggdistinct; - numSortCols = numDistinctCols = list_length(sortlist); - Assert(numSortCols >= list_length(aggref->aggorder)); + /* + * Existing compatible trans found, so just point the 'peragg' to + * the same per-trans struct. + */ + pertrans = &pertransstates[existing_transno]; + peragg->transno = existing_transno; } else { - sortlist = aggref->aggorder; - numSortCols = list_length(sortlist); - numDistinctCols = 0; + pertrans = &pertransstates[++transno]; + build_pertrans_for_aggref(pertrans, aggstate, estate, + aggref, transfn_oid, aggtranstype, + initValue, initValueIsNull, + inputTypes, numArguments); + peragg->transno = transno; } - - peraggstate->numSortCols = numSortCols; - peraggstate->numDistinctCols = numDistinctCols; - - if (numSortCols > 0) - { - /* - * We don't implement DISTINCT or ORDER BY aggs in the HASHED case - * (yet) - */ - Assert(node->aggstrategy != AGG_HASHED); - - /* If we have only one input, we need its len/byval info. 
*/ - if (numInputs == 1) - { - get_typlenbyval(inputTypes[numDirectArgs], - &peraggstate->inputtypeLen, - &peraggstate->inputtypeByVal); - } - else if (numDistinctCols > 0) - { - /* we will need an extra slot to store prior values */ - peraggstate->uniqslot = ExecInitExtraTupleSlot(estate); - ExecSetSlotDescriptor(peraggstate->uniqslot, - peraggstate->evaldesc); - } - - /* Extract the sort information for use later */ - peraggstate->sortColIdx = - (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber)); - peraggstate->sortOperators = - (Oid *) palloc(numSortCols * sizeof(Oid)); - peraggstate->sortCollations = - (Oid *) palloc(numSortCols * sizeof(Oid)); - peraggstate->sortNullsFirst = - (bool *) palloc(numSortCols * sizeof(bool)); - - i = 0; - foreach(lc, sortlist) - { - SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); - TargetEntry *tle = get_sortgroupclause_tle(sortcl, - aggref->args); - - /* the parser should have made sure of this */ - Assert(OidIsValid(sortcl->sortop)); - - peraggstate->sortColIdx[i] = tle->resno; - peraggstate->sortOperators[i] = sortcl->sortop; - peraggstate->sortCollations[i] = exprCollation((Node *) tle->expr); - peraggstate->sortNullsFirst[i] = sortcl->nulls_first; - i++; - } - Assert(i == numSortCols); - } - - if (aggref->aggdistinct) - { - Assert(numArguments > 0); - - /* - * We need the equal function for each DISTINCT comparison we will - * make. - */ - peraggstate->equalfns = - (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo)); - - i = 0; - foreach(lc, aggref->aggdistinct) - { - SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); - - fmgr_info(get_opcode(sortcl->eqop), &peraggstate->equalfns[i]); - i++; - } - Assert(i == numDistinctCols); - } - ReleaseSysCache(aggTuple); } - /* Update numaggs to match number of unique aggregates found */ + /* + * Update numaggs to match the number of unique aggregates found. Also set + * numtrans to the number of unique aggregate states found. 
+ */ aggstate->numaggs = aggno + 1; + aggstate->numtrans = transno + 1; return aggstate; } +/* + * Build the state needed to calculate a state value for an aggregate. + * + * This initializes all the fields in 'pertrans'. 'aggref' is the aggregate + * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest + * of the arguments could be calculated from 'aggref', but the caller has + * calculated them already, so might as well pass them. + */ +static void +build_pertrans_for_aggref(AggStatePerTrans pertrans, + AggState *aggstate, EState *estate, + Aggref *aggref, + Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + Oid *inputTypes, int numArguments) +{ + int numGroupingSets = Max(aggstate->maxsets, 1); + Expr *transfnexpr; + ListCell *lc; + int numInputs; + int numDirectArgs; + List *sortlist; + int numSortCols; + int numDistinctCols; + int naggs; + int i; + + /* Begin filling in the pertrans data */ + pertrans->aggref = aggref; + pertrans->aggCollation = aggref->inputcollid; + pertrans->transfn_oid = aggtransfn; + pertrans->initValue = initValue; + pertrans->initValueIsNull = initValueIsNull; + + /* Count the "direct" arguments, if any */ + numDirectArgs = list_length(aggref->aggdirectargs); + + /* Count the number of aggregated input columns */ + pertrans->numInputs = numInputs = list_length(aggref->args); + + pertrans->aggtranstype = aggtranstype; + + /* Detect how many arguments to pass to the transfn */ + if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + pertrans->numTransInputs = numInputs; + else + pertrans->numTransInputs = numArguments; + + /* + * Set up infrastructure for calling the transfn + */ + build_aggregate_transfn_expr(inputTypes, + numArguments, + numDirectArgs, + aggref->aggvariadic, + aggtranstype, + aggref->inputcollid, + aggtransfn, + InvalidOid, /* invtrans is not needed here */ + &transfnexpr, + NULL); + fmgr_info(aggtransfn, &pertrans->transfn); + fmgr_info_set_expr((Node *) transfnexpr, 
&pertrans->transfn); + + InitFunctionCallInfoData(pertrans->transfn_fcinfo, + &pertrans->transfn, + pertrans->numTransInputs + 1, + pertrans->aggCollation, + (void *) aggstate, NULL); + + /* + * If the transfn is strict and the initval is NULL, make sure input type + * and transtype are the same (or at least binary-compatible), so that + * it's OK to use the first aggregated input value as the initial + * transValue. This should have been checked at agg definition time, but + * we must check again in case the transfn's strictness property has been + * changed. + */ + if (pertrans->transfn.fn_strict && pertrans->initValueIsNull) + { + if (numArguments <= numDirectArgs || + !IsBinaryCoercible(inputTypes[numDirectArgs], + aggtranstype)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate %u needs to have compatible input type and transition type", + aggref->aggfnoid))); + } + + /* get info about the state value's datatype */ + get_typlenbyval(aggtranstype, + &pertrans->transtypeLen, + &pertrans->transtypeByVal); + + /* + * Get a tupledesc corresponding to the aggregated inputs (including sort + * expressions) of the agg. + */ + pertrans->evaldesc = ExecTypeFromTL(aggref->args, false); + + /* Create slot we're going to do argument evaluation in */ + pertrans->evalslot = ExecInitExtraTupleSlot(estate); + ExecSetSlotDescriptor(pertrans->evalslot, pertrans->evaldesc); + + /* Initialize the input and FILTER expressions */ + naggs = aggstate->numaggs; + pertrans->aggfilter = ExecInitExpr(aggref->aggfilter, + (PlanState *) aggstate); + pertrans->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs, + (PlanState *) aggstate); + pertrans->args = (List *) ExecInitExpr((Expr *) aggref->args, + (PlanState *) aggstate); + + /* + * Complain if the aggregate's arguments contain any aggregates; nested + * agg functions are semantically nonsensical. (This should have been + * caught earlier, but we defend against it here anyway.) 
+ */ + if (naggs != aggstate->numaggs) + ereport(ERROR, + (errcode(ERRCODE_GROUPING_ERROR), + errmsg("aggregate function calls cannot be nested"))); + + /* Set up projection info for evaluation */ + pertrans->evalproj = ExecBuildProjectionInfo(pertrans->args, + aggstate->tmpcontext, + pertrans->evalslot, + NULL); + + /* + * If we're doing either DISTINCT or ORDER BY for a plain agg, then we + * have a list of SortGroupClause nodes; fish out the data in them and + * stick them into arrays. We ignore ORDER BY for an ordered-set agg, + * however; the agg's transfn and finalfn are responsible for that. + * + * Note that by construction, if there is a DISTINCT clause then the ORDER + * BY clause is a prefix of it (see transformDistinctClause). + */ + if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) + { + sortlist = NIL; + numSortCols = numDistinctCols = 0; + } + else if (aggref->aggdistinct) + { + sortlist = aggref->aggdistinct; + numSortCols = numDistinctCols = list_length(sortlist); + Assert(numSortCols >= list_length(aggref->aggorder)); + } + else + { + sortlist = aggref->aggorder; + numSortCols = list_length(sortlist); + numDistinctCols = 0; + } + + pertrans->numSortCols = numSortCols; + pertrans->numDistinctCols = numDistinctCols; + + if (numSortCols > 0) + { + /* + * We don't implement DISTINCT or ORDER BY aggs in the HASHED case + * (yet) + */ + Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED); + + /* If we have only one input, we need its len/byval info. 
*/ + if (numInputs == 1) + { + get_typlenbyval(inputTypes[numDirectArgs], + &pertrans->inputtypeLen, + &pertrans->inputtypeByVal); + } + else if (numDistinctCols > 0) + { + /* we will need an extra slot to store prior values */ + pertrans->uniqslot = ExecInitExtraTupleSlot(estate); + ExecSetSlotDescriptor(pertrans->uniqslot, + pertrans->evaldesc); + } + + /* Extract the sort information for use later */ + pertrans->sortColIdx = + (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber)); + pertrans->sortOperators = + (Oid *) palloc(numSortCols * sizeof(Oid)); + pertrans->sortCollations = + (Oid *) palloc(numSortCols * sizeof(Oid)); + pertrans->sortNullsFirst = + (bool *) palloc(numSortCols * sizeof(bool)); + + i = 0; + foreach(lc, sortlist) + { + SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); + TargetEntry *tle = get_sortgroupclause_tle(sortcl, aggref->args); + + /* the parser should have made sure of this */ + Assert(OidIsValid(sortcl->sortop)); + + pertrans->sortColIdx[i] = tle->resno; + pertrans->sortOperators[i] = sortcl->sortop; + pertrans->sortCollations[i] = exprCollation((Node *) tle->expr); + pertrans->sortNullsFirst[i] = sortcl->nulls_first; + i++; + } + Assert(i == numSortCols); + } + + if (aggref->aggdistinct) + { + Assert(numArguments > 0); + + /* + * We need the equal function for each DISTINCT comparison we will + * make. 
+ */ + pertrans->equalfns = + (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo)); + + i = 0; + foreach(lc, aggref->aggdistinct) + { + SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); + + fmgr_info(get_opcode(sortcl->eqop), &pertrans->equalfns[i]); + i++; + } + Assert(i == numDistinctCols); + } + + pertrans->sortstates = (Tuplesortstate **) + palloc0(sizeof(Tuplesortstate *) * numGroupingSets); +} + + static Datum GetAggInitVal(Datum textInitVal, Oid transtype) { @@ -2596,11 +2787,130 @@ GetAggInitVal(Datum textInitVal, Oid transtype) return initVal; } +/* + * find_compatible_peragg - search for a previously initialized per-Agg struct + * + * Searches the previously looked at aggregates to find one which is compatible + * with this one, with the same input parameters. If no compatible aggregate + * can be found, returns -1. + * + * As a side-effect, this also collects a list of existing per-Trans structs + * with matching inputs. If no identical Aggref is found, the list is passed + * later to find_compatible_pertrans, to see if we can at least reuse the + * state value of another aggregate. + */ +static int +find_compatible_peragg(Aggref *newagg, AggState *aggstate, + int lastaggno, List **same_input_transnos) +{ + int aggno; + AggStatePerAgg peraggs; + + *same_input_transnos = NIL; + + /* we mustn't reuse the aggref if it contains volatile function calls */ + if (contain_volatile_functions((Node *) newagg)) + return -1; + + peraggs = aggstate->peragg; + + /* + * Search through the list of already seen aggregates. If we find an + * existing aggregate with the same aggregate function and input + * parameters as the new one, then we can re-use that one. While + * searching, we'll also collect a list of Aggrefs with the same input + * parameters. If no matching Aggref is found, the caller can potentially + * still re-use the transition state of one of them. 
+ */ + for (aggno = 0; aggno <= lastaggno; aggno++) + { + AggStatePerAgg peragg; + Aggref *existingRef; + + peragg = &peraggs[aggno]; + existingRef = peragg->aggref; + + /* all of the following must be the same or it's no match */ + if (newagg->inputcollid != existingRef->inputcollid || + newagg->aggstar != existingRef->aggstar || + newagg->aggvariadic != existingRef->aggvariadic || + newagg->aggkind != existingRef->aggkind || + !equal(newagg->aggdirectargs, existingRef->aggdirectargs) || + !equal(newagg->args, existingRef->args) || + !equal(newagg->aggorder, existingRef->aggorder) || + !equal(newagg->aggdistinct, existingRef->aggdistinct) || + !equal(newagg->aggfilter, existingRef->aggfilter)) + continue; + + /* if it's the same aggregate function then report exact match */ + if (newagg->aggfnoid == existingRef->aggfnoid && + newagg->aggtype == existingRef->aggtype && + newagg->aggcollid == existingRef->aggcollid) + { + list_free(*same_input_transnos); + *same_input_transnos = NIL; + return aggno; + } + + /* + * Not identical, but it had the same inputs. Return it to the caller, + * in case we can re-use its per-trans state. + */ + *same_input_transnos = lappend_int(*same_input_transnos, + peragg->transno); + } + + return -1; +} + +/* + * find_compatible_pertrans - search for a previously initialized per-Trans + * struct + * + * Searches the list of transnos for a per-Trans struct with the same + * transition state and initial condition. (The inputs have already been + * verified to match.) + */ +static int +find_compatible_pertrans(AggState *aggstate, Aggref *newagg, + Oid aggtransfn, Oid aggtranstype, + Datum initValue, bool initValueIsNull, + List *transnos) +{ + ListCell *lc; + + foreach(lc, transnos) + { + int transno = lfirst_int(lc); + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + + /* + * if the transfns or transition state types are not the same then the + * state can't be shared. 
+ */ + if (aggtransfn != pertrans->transfn_oid || + aggtranstype != pertrans->aggtranstype) + continue; + + /* Check that the initial condition matches, too. */ + if (initValueIsNull && pertrans->initValueIsNull) + return transno; + + if (!initValueIsNull && !pertrans->initValueIsNull && + datumIsEqual(initValue, pertrans->initValue, + pertrans->transtypeByVal, pertrans->transtypeLen)) + { + return transno; + } + } + return -1; +} + void ExecEndAgg(AggState *node) { PlanState *outerPlan; - int aggno; + int transno; int numGroupingSets = Max(node->maxsets, 1); int setno; @@ -2611,14 +2921,14 @@ ExecEndAgg(AggState *node) if (node->sort_out) tuplesort_end(node->sort_out); - for (aggno = 0; aggno < node->numaggs; aggno++) + for (transno = 0; transno < node->numtrans; transno++) { - AggStatePerAgg peraggstate = &node->peragg[aggno]; + AggStatePerTrans pertrans = &node->pertrans[transno]; for (setno = 0; setno < numGroupingSets; setno++) { - if (peraggstate->sortstates[setno]) - tuplesort_end(peraggstate->sortstates[setno]); + if (pertrans->sortstates[setno]) + tuplesort_end(pertrans->sortstates[setno]); } } @@ -2646,7 +2956,7 @@ ExecReScanAgg(AggState *node) ExprContext *econtext = node->ss.ps.ps_ExprContext; PlanState *outerPlan = outerPlanState(node); Agg *aggnode = (Agg *) node->ss.ps.plan; - int aggno; + int transno; int numGroupingSets = Max(node->maxsets, 1); int setno; @@ -2678,16 +2988,16 @@ ExecReScanAgg(AggState *node) } /* Make sure we have closed any open tuplesorts */ - for (aggno = 0; aggno < node->numaggs; aggno++) + for (transno = 0; transno < node->numtrans; transno++) { for (setno = 0; setno < numGroupingSets; setno++) { - AggStatePerAgg peraggstate = &node->peragg[aggno]; + AggStatePerTrans pertrans = &node->pertrans[transno]; - if (peraggstate->sortstates[setno]) + if (pertrans->sortstates[setno]) { - tuplesort_end(peraggstate->sortstates[setno]); - peraggstate->sortstates[setno] = NULL; + tuplesort_end(pertrans->sortstates[setno]); + 
pertrans->sortstates[setno] = NULL; } } } @@ -2811,10 +3121,12 @@ AggGetAggref(FunctionCallInfo fcinfo) { if (fcinfo->context && IsA(fcinfo->context, AggState)) { - AggStatePerAgg curperagg = ((AggState *) fcinfo->context)->curperagg; + AggStatePerTrans curpertrans; - if (curperagg) - return curperagg->aggref; + curpertrans = ((AggState *) fcinfo->context)->curpertrans; + + if (curpertrans) + return curpertrans->aggref; } return NULL; } diff --git a/src/backend/executor/nodeWindowAgg.c b/src/backend/executor/nodeWindowAgg.c index ecf96f8c19..c371d4db14 100644 --- a/src/backend/executor/nodeWindowAgg.c +++ b/src/backend/executor/nodeWindowAgg.c @@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc, numArguments); /* build expression trees using actual argument & result types */ - build_aggregate_fnexprs(inputTypes, - numArguments, - 0, /* no ordered-set window functions yet */ - peraggstate->numFinalArgs, - false, /* no variadic window functions yet */ - aggtranstype, - wfunc->wintype, - wfunc->inputcollid, - transfn_oid, - invtransfn_oid, - finalfn_oid, - &transfnexpr, - &invtransfnexpr, - &finalfnexpr); + build_aggregate_transfn_expr(inputTypes, + numArguments, + 0, /* no ordered-set window functions yet */ + false, /* no variadic window functions yet */ + wfunc->wintype, + wfunc->inputcollid, + transfn_oid, + invtransfn_oid, + &transfnexpr, + &invtransfnexpr); /* set up infrastructure for calling the transfn(s) and finalfn */ fmgr_info(transfn_oid, &peraggstate->transfn); @@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc, if (OidIsValid(finalfn_oid)) { + build_aggregate_finalfn_expr(inputTypes, + peraggstate->numFinalArgs, + aggtranstype, + wfunc->wintype, + wfunc->inputcollid, + finalfn_oid, + &finalfnexpr); fmgr_info(finalfn_oid, &peraggstate->finalfn); fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn); } diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c 
index 3846b569d6..5b0d568478 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -1829,44 +1829,40 @@ resolve_aggregate_transtype(Oid aggfuncid, } /* - * Create expression trees for the transition and final functions - * of an aggregate. These are needed so that polymorphic functions - * can be used within an aggregate --- without the expression trees, - * such functions would not know the datatypes they are supposed to use. - * (The trees will never actually be executed, however, so we can skimp - * a bit on correctness.) + * Create an expression tree for the transition function of an aggregate. + * This is needed so that polymorphic functions can be used within an + * aggregate --- without the expression tree, such functions would not know + * the datatypes they are supposed to use. (The trees will never actually + * be executed, however, so we can skimp a bit on correctness.) * - * agg_input_types, agg_state_type, agg_result_type identify the input, - * transition, and result types of the aggregate. These should all be - * resolved to actual types (ie, none should ever be ANYELEMENT etc). + * agg_input_types and agg_state_type identify the input and transition + * types of the aggregate. These should be resolved to actual types (ie, + * none should ever be ANYELEMENT etc). * agg_input_collation is the aggregate function's input collation. * * For an ordered-set aggregate, remember that agg_input_types describes * the direct arguments followed by the aggregated arguments. * - * transfn_oid, invtransfn_oid and finalfn_oid identify the funcs to be - * called; the latter two may be InvalidOid. + * transfn_oid and invtransfn_oid identify the funcs to be called; the + * latter may be InvalidOid, however if invtransfn_oid is set then + * transfn_oid must also be set. * * Pointers to the constructed trees are returned into *transfnexpr, - * *invtransfnexpr and *finalfnexpr. 
If there is no invtransfn or finalfn, - * the respective pointers are set to NULL. Since use of the invtransfn is - * optional, NULL may be passed for invtransfnexpr. + * *invtransfnexpr. If there is no invtransfn, the respective pointer is set + * to NULL. Since use of the invtransfn is optional, NULL may be passed for + * invtransfnexpr. */ void -build_aggregate_fnexprs(Oid *agg_input_types, - int agg_num_inputs, - int agg_num_direct_inputs, - int num_finalfn_inputs, - bool agg_variadic, - Oid agg_state_type, - Oid agg_result_type, - Oid agg_input_collation, - Oid transfn_oid, - Oid invtransfn_oid, - Oid finalfn_oid, - Expr **transfnexpr, - Expr **invtransfnexpr, - Expr **finalfnexpr) +build_aggregate_transfn_expr(Oid *agg_input_types, + int agg_num_inputs, + int agg_num_direct_inputs, + bool agg_variadic, + Oid agg_state_type, + Oid agg_input_collation, + Oid transfn_oid, + Oid invtransfn_oid, + Expr **transfnexpr, + Expr **invtransfnexpr) { Param *argp; List *args; @@ -1929,13 +1925,24 @@ build_aggregate_fnexprs(Oid *agg_input_types, else *invtransfnexpr = NULL; } +} - /* see if we have a final function */ - if (!OidIsValid(finalfn_oid)) - { - *finalfnexpr = NULL; - return; - } +/* + * Like build_aggregate_transfn_expr, but creates an expression tree for the + * final function of an aggregate, rather than the transition function. 
+ */ +void +build_aggregate_finalfn_expr(Oid *agg_input_types, + int num_finalfn_inputs, + Oid agg_state_type, + Oid agg_result_type, + Oid agg_input_collation, + Oid finalfn_oid, + Expr **finalfnexpr) +{ + Param *argp; + List *args; + int i; /* * Build expr tree for final function diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h index 303fc3c1c7..5796de861c 100644 --- a/src/include/nodes/execnodes.h +++ b/src/include/nodes/execnodes.h @@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState typedef struct AggrefExprState { ExprState xprstate; - List *aggdirectargs; /* states of direct-argument expressions */ - List *args; /* states of aggregated-argument expressions */ - ExprState *aggfilter; /* state of FILTER expression, if any */ int aggno; /* ID number for agg within its plan node */ } AggrefExprState; @@ -1825,6 +1822,7 @@ typedef struct GroupState */ /* these structs are private in nodeAgg.c: */ typedef struct AggStatePerAggData *AggStatePerAgg; +typedef struct AggStatePerTransData *AggStatePerTrans; typedef struct AggStatePerGroupData *AggStatePerGroup; typedef struct AggStatePerPhaseData *AggStatePerPhase; @@ -1833,14 +1831,16 @@ typedef struct AggState ScanState ss; /* its first field is NodeTag */ List *aggs; /* all Aggref nodes in targetlist & quals */ int numaggs; /* length of list (could be zero!) 
*/ + int numtrans; /* number of pertrans items */ AggStatePerPhase phase; /* pointer to current phase data */ int numphases; /* number of phases */ int current_phase; /* current phase number */ FmgrInfo *hashfunctions; /* per-grouping-field hash fns */ AggStatePerAgg peragg; /* per-Aggref information */ + AggStatePerTrans pertrans; /* per-Trans state information */ ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */ ExprContext *tmpcontext; /* econtext for input expressions */ - AggStatePerAgg curperagg; /* identifies currently active aggregate */ + AggStatePerTrans curpertrans; /* currently active trans state */ bool input_done; /* indicates end of input */ bool agg_done; /* indicates completion of Agg scan */ int projected_set; /* The last projected grouping set */ diff --git a/src/include/parser/parse_agg.h b/src/include/parser/parse_agg.h index 6a5f9bbdf1..e2b3894c28 100644 --- a/src/include/parser/parse_agg.h +++ b/src/include/parser/parse_agg.h @@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid, Oid *inputTypes, int numArguments); -extern void build_aggregate_fnexprs(Oid *agg_input_types, +extern void build_aggregate_transfn_expr(Oid *agg_input_types, int agg_num_inputs, int agg_num_direct_inputs, - int num_finalfn_inputs, bool agg_variadic, Oid agg_state_type, - Oid agg_result_type, Oid agg_input_collation, Oid transfn_oid, Oid invtransfn_oid, - Oid finalfn_oid, Expr **transfnexpr, - Expr **invtransfnexpr, + Expr **invtransfnexpr); + +extern void build_aggregate_finalfn_expr(Oid *agg_input_types, + int num_finalfn_inputs, + Oid agg_state_type, + Oid agg_result_type, + Oid agg_input_collation, + Oid finalfn_oid, Expr **finalfnexpr); #endif /* PARSE_AGG_H */ diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index 8852051e93..de826b5e50 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -1580,3 +1580,207 @@ select 
least_agg(variadic array[q1,q2]) from int8_tbl; -4567890123456789 (1 row) +-- test aggregates with common transition functions share the same states +begin work; +create type avg_state as (total bigint, count bigint); +create or replace function avg_transfn(state avg_state, n int) returns avg_state as +$$ +declare new_state avg_state; +begin + raise notice 'avg_transfn called with %', n; + if state is null then + if n is not null then + new_state.total := n; + new_state.count := 1; + return new_state; + end if; + return null; + elsif n is not null then + state.total := state.total + n; + state.count := state.count + 1; + return state; + end if; + + return null; +end +$$ language plpgsql; +create function avg_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total / state.count; + end if; +end +$$ language plpgsql; +create function sum_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total; + end if; +end +$$ language plpgsql; +create aggregate my_avg(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn +); +create aggregate my_sum(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn +); +-- aggregate state should be shared as aggs are the same. +select my_avg(one),my_avg(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_avg +--------+-------- + 2 | 2 +(1 row) + +-- aggregate state should be shared as transfn is the same for both aggs. +select my_avg(one),my_sum(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 4 +(1 row) + +-- shouldn't share states due to the distinctness not matching. 
+select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 4 +(1 row) + +-- shouldn't share states due to the filter clause not matching. +select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 3 | 4 +(1 row) + +-- this should not share the state due to different input columns. +select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two); +NOTICE: avg_transfn called with 2 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 4 +NOTICE: avg_transfn called with 3 + my_avg | my_sum +--------+-------- + 2 | 6 +(1 row) + +-- test that aggs with the same sfunc and initcond share the same agg state +create aggregate my_sum_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn, + initcond = '(10,0)' +); +create aggregate my_avg_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(10,0)' +); +create aggregate my_avg_init2(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(4,0)' +); +-- state should be shared if INITCONDs are matching +select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 + my_sum_init | my_avg_init +-------------+------------- + 14 | 7 +(1 row) + +-- Varying INITCONDs should cause the states not to be shared. 
+select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one); +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 1 +NOTICE: avg_transfn called with 3 +NOTICE: avg_transfn called with 3 + my_sum_init | my_avg_init2 +-------------+-------------- + 14 | 4 +(1 row) + +rollback; +-- test aggregate state sharing to ensure it works if one aggregate has a +-- finalfn and the other one has none. +begin work; +create or replace function sum_transfn(state int4, n int4) returns int4 as +$$ +declare new_state int4; +begin + raise notice 'sum_transfn called with %', n; + if state is null then + if n is not null then + new_state := n; + return new_state; + end if; + return null; + elsif n is not null then + state := state + n; + return state; + end if; + + return null; +end +$$ language plpgsql; +create function halfsum_finalfn(state int4) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state / 2; + end if; +end +$$ language plpgsql; +create aggregate my_sum(int4) +( + stype = int4, + sfunc = sum_transfn +); +create aggregate my_half_sum(int4) +( + stype = int4, + sfunc = sum_transfn, + finalfunc = halfsum_finalfn +); +-- Agg state should be shared even though my_sum has no finalfn +select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one); +NOTICE: sum_transfn called with 1 +NOTICE: sum_transfn called with 2 +NOTICE: sum_transfn called with 3 +NOTICE: sum_transfn called with 4 + my_sum | my_half_sum +--------+------------- + 10 | 5 +(1 row) + +rollback; diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index a84327d24c..8d501dc008 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -590,3 +590,168 @@ drop view aggordview1; -- variadic aggregates select least_agg(q1,q2) from int8_tbl; select least_agg(variadic array[q1,q2]) from int8_tbl; + + +-- test aggregates with common transition functions share the same states +begin work; 
+ +create type avg_state as (total bigint, count bigint); + +create or replace function avg_transfn(state avg_state, n int) returns avg_state as +$$ +declare new_state avg_state; +begin + raise notice 'avg_transfn called with %', n; + if state is null then + if n is not null then + new_state.total := n; + new_state.count := 1; + return new_state; + end if; + return null; + elsif n is not null then + state.total := state.total + n; + state.count := state.count + 1; + return state; + end if; + + return null; +end +$$ language plpgsql; + +create function avg_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total / state.count; + end if; +end +$$ language plpgsql; + +create function sum_finalfn(state avg_state) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state.total; + end if; +end +$$ language plpgsql; + +create aggregate my_avg(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn +); + +create aggregate my_sum(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn +); + +-- aggregate state should be shared as aggs are the same. +select my_avg(one),my_avg(one) from (values(1),(3)) t(one); + +-- aggregate state should be shared as transfn is the same for both aggs. +select my_avg(one),my_sum(one) from (values(1),(3)) t(one); + +-- shouldn't share states due to the distinctness not matching. +select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one); + +-- shouldn't share states due to the filter clause not matching. +select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one); + +-- this should not share the state due to different input columns. 
+select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two); + +-- test that aggs with the same sfunc and initcond share the same agg state +create aggregate my_sum_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = sum_finalfn, + initcond = '(10,0)' +); + +create aggregate my_avg_init(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(10,0)' +); + +create aggregate my_avg_init2(int4) +( + stype = avg_state, + sfunc = avg_transfn, + finalfunc = avg_finalfn, + initcond = '(4,0)' +); + +-- state should be shared if INITCONDs are matching +select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one); + +-- Varying INITCONDs should cause the states not to be shared. +select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one); + +rollback; + +-- test aggregate state sharing to ensure it works if one aggregate has a +-- finalfn and the other one has none. +begin work; + +create or replace function sum_transfn(state int4, n int4) returns int4 as +$$ +declare new_state int4; +begin + raise notice 'sum_transfn called with %', n; + if state is null then + if n is not null then + new_state := n; + return new_state; + end if; + return null; + elsif n is not null then + state := state + n; + return state; + end if; + + return null; +end +$$ language plpgsql; + +create function halfsum_finalfn(state int4) returns int4 as +$$ +begin + if state is null then + return NULL; + else + return state / 2; + end if; +end +$$ language plpgsql; + +create aggregate my_sum(int4) +( + stype = int4, + sfunc = sum_transfn +); + +create aggregate my_half_sum(int4) +( + stype = int4, + sfunc = sum_transfn, + finalfunc = halfsum_finalfn +); + +-- Agg state should be shared even though my_sum has no finalfn +select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one); + +rollback;