diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index da6ef1a94c..a3454e52f6 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -532,13 +532,14 @@ static void select_current_set(AggState *aggstate, int setno, bool is_hash); static void initialize_phase(AggState *aggstate, int newphase); static TupleTableSlot *fetch_input_tuple(AggState *aggstate); static void initialize_aggregates(AggState *aggstate, - AggStatePerGroup pergroup, + AggStatePerGroup *pergroups, int numReset); static void advance_transition_function(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); -static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, - AggStatePerGroup *pergroups); +static void advance_aggregates(AggState *aggstate, + AggStatePerGroup *sort_pergroups, + AggStatePerGroup *hash_pergroups); static void advance_combine_function(AggState *aggstate, AggStatePerTrans pertrans, AggStatePerGroup pergroupstate); @@ -793,14 +794,16 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans, * If there are multiple grouping sets, we initialize only the first numReset * of them (the grouping sets are ordered so that the most specific one, which * is reset most often, is first). As a convenience, if numReset is 0, we - * reinitialize all sets. numReset is -1 to initialize a hashtable entry, in - * which case the caller must have used select_current_set appropriately. + * reinitialize all sets. + * + * NB: This cannot be used for hash aggregates, as for those the grouping set + * number has to be specified from further up. * * When called, CurrentMemoryContext should be the per-query context. */ static void initialize_aggregates(AggState *aggstate, - AggStatePerGroup pergroup, + AggStatePerGroup *pergroups, int numReset) { int transno; @@ -812,31 +815,19 @@ initialize_aggregates(AggState *aggstate, if (numReset == 0) numReset = numGroupingSets; - for (transno = 0; transno < numTrans; transno++) + for (setno = 0; setno < numReset; setno++) { - AggStatePerTrans pertrans = &transstates[transno]; + AggStatePerGroup pergroup = pergroups[setno]; - if (numReset < 0) + select_current_set(aggstate, setno, false); + + for (transno = 0; transno < numTrans; transno++) { - AggStatePerGroup pergroupstate; - - pergroupstate = &pergroup[transno]; + AggStatePerTrans pertrans = &transstates[transno]; + AggStatePerGroup pergroupstate = &pergroup[transno]; initialize_aggregate(aggstate, pertrans, pergroupstate); } - else - { - for (setno = 0; setno < numReset; setno++) - { - AggStatePerGroup pergroupstate; - - pergroupstate = &pergroup[transno + (setno * numTrans)]; - - select_current_set(aggstate, setno, false); - - initialize_aggregate(aggstate, pertrans, pergroupstate); - } - } } } @@ -976,7 +967,9 @@ advance_transition_function(AggState *aggstate, * When called, CurrentMemoryContext should be the per-query context. */ static void -advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGroup *pergroups) +advance_aggregates(AggState *aggstate, + AggStatePerGroup *sort_pergroups, + AggStatePerGroup *hash_pergroups) { int transno; int setno = 0; @@ -1019,7 +1012,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro { /* DISTINCT and/or ORDER BY case */ Assert(slot->tts_nvalid >= (pertrans->numInputs + inputoff)); - Assert(!pergroups); + Assert(!hash_pergroups); /* * If the transfn is strict, we want to check for nullity before @@ -1090,7 +1083,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff]; } - if (pergroup) + if (sort_pergroups) { /* advance transition states for ordered grouping */ @@ -1100,13 +1093,13 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro select_current_set(aggstate, setno, false); - pergroupstate = &pergroup[transno + (setno * numTrans)]; + pergroupstate = &sort_pergroups[setno][transno]; advance_transition_function(aggstate, pertrans, pergroupstate); } } - if (pergroups) + if (hash_pergroups) { /* advance transition states for hashed grouping */ @@ -1116,7 +1109,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro select_current_set(aggstate, setno, true); - pergroupstate = &pergroups[setno][transno]; + pergroupstate = &hash_pergroups[setno][transno]; advance_transition_function(aggstate, pertrans, pergroupstate); } @@ -2095,12 +2088,25 @@ lookup_hash_entry(AggState *aggstate) if (isnew) { - entry->additional = (AggStatePerGroup) + AggStatePerGroup pergroup; + int transno; + + pergroup = (AggStatePerGroup) MemoryContextAlloc(perhash->hashtable->tablecxt, sizeof(AggStatePerGroupData) * aggstate->numtrans); - /* initialize aggregates for new tuple group */ - initialize_aggregates(aggstate, (AggStatePerGroup) entry->additional, - -1); + entry->additional = pergroup; + + /* + * Initialize aggregates for new tuple group, lookup_hash_entries() + * already has selected the relevant grouping set. + */ + for (transno = 0; transno < aggstate->numtrans; transno++) + { + AggStatePerTrans pertrans = &aggstate->pertrans[transno]; + AggStatePerGroup pergroupstate = &pergroup[transno]; + + initialize_aggregate(aggstate, pertrans, pergroupstate); + } } return entry; @@ -2184,7 +2190,7 @@ agg_retrieve_direct(AggState *aggstate) ExprContext *econtext; ExprContext *tmpcontext; AggStatePerAgg peragg; - AggStatePerGroup pergroup; + AggStatePerGroup *pergroups; AggStatePerGroup *hash_pergroups = NULL; TupleTableSlot *outerslot; TupleTableSlot *firstSlot; @@ -2207,7 +2213,7 @@ agg_retrieve_direct(AggState *aggstate) tmpcontext = aggstate->tmpcontext; peragg = aggstate->peragg; - pergroup = aggstate->pergroup; + pergroups = aggstate->pergroups; firstSlot = aggstate->ss.ss_ScanTupleSlot; /* @@ -2409,7 +2415,7 @@ agg_retrieve_direct(AggState *aggstate) /* * Initialize working state for a new input tuple group. */ - initialize_aggregates(aggstate, pergroup, numReset); + initialize_aggregates(aggstate, pergroups, numReset); if (aggstate->grp_firstTuple != NULL) { @@ -2446,9 +2452,9 @@ agg_retrieve_direct(AggState *aggstate) hash_pergroups = NULL; if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit)) - combine_aggregates(aggstate, pergroup); + combine_aggregates(aggstate, pergroups[0]); else - advance_aggregates(aggstate, pergroup, hash_pergroups); + advance_aggregates(aggstate, pergroups, hash_pergroups); /* Reset per-input-tuple context after each tuple */ ResetExprContext(tmpcontext); @@ -2512,7 +2518,7 @@ agg_retrieve_direct(AggState *aggstate) finalize_aggregates(aggstate, peragg, - pergroup + (currentSet * aggstate->numtrans)); + pergroups[currentSet]); /* * If there's no row to project right now, we must continue rather @@ -2756,7 +2762,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) aggstate->curpertrans = NULL; aggstate->input_done = false; aggstate->agg_done = false; - aggstate->pergroup = NULL; + aggstate->pergroups = NULL; aggstate->grp_firstTuple = NULL; aggstate->sort_in = NULL; aggstate->sort_out = NULL; @@ -3052,13 +3058,16 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) if (node->aggstrategy != AGG_HASHED) { - AggStatePerGroup pergroup; + AggStatePerGroup *pergroups; - pergroup = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData) - * numaggs - * numGroupingSets); + pergroups = (AggStatePerGroup *) palloc0(sizeof(AggStatePerGroup) * + numGroupingSets); - aggstate->pergroup = pergroup; + for (i = 0; i < numGroupingSets; i++) + pergroups[i] = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData) + * numaggs); + + aggstate->pergroups = pergroups; } /* @@ -4086,8 +4095,11 @@ ExecReScanAgg(AggState *node) /* * Reset the per-group state (in particular, mark transvalues null) */ - MemSet(node->pergroup, 0, - sizeof(AggStatePerGroupData) * node->numaggs * numGroupingSets); + for (setno = 0; setno < numGroupingSets; setno++) + { + MemSet(node->pergroups[setno], 0, + sizeof(AggStatePerGroupData) * node->numaggs); + } /* reset to phase 1 */ initialize_phase(node, 1); diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h index 94351eafad..bbc3ec3f3f 100644 --- a/src/include/nodes/execnodes.h +++ b/src/include/nodes/execnodes.h @@ -1852,13 +1852,15 @@ typedef struct AggState Tuplesortstate *sort_out; /* input is copied here for next phase */ TupleTableSlot *sort_slot; /* slot for sort results */ /* these fields are used in AGG_PLAIN and AGG_SORTED modes: */ - AggStatePerGroup pergroup; /* per-Aggref-per-group working state */ + AggStatePerGroup *pergroups; /* grouping set indexed array of per-group + * pointers */ HeapTuple grp_firstTuple; /* copy of first tuple of current group */ /* these fields are used in AGG_HASHED and AGG_MIXED modes: */ bool table_filled; /* hash table filled yet? */ int num_hashes; AggStatePerHash perhash; - AggStatePerGroup *hash_pergroup; /* array of per-group pointers */ + AggStatePerGroup *hash_pergroup; /* grouping set indexed array of + * per-group pointers */ /* support for evaluation of agg input expressions: */ ProjectionInfo *combinedproj; /* projection machinery */ } AggState;