diff --git a/src/backend/catalog/pg_aggregate.c b/src/backend/catalog/pg_aggregate.c index 66b4af93bd..246776093e 100644 --- a/src/backend/catalog/pg_aggregate.c +++ b/src/backend/catalog/pg_aggregate.c @@ -410,16 +410,17 @@ AggregateCreate(const char *aggName, Oid combineType; /* - * Combine function must have 2 argument, each of which is the trans - * type + * Combine function must have 2 arguments, each of which is the trans + * type. VARIADIC doesn't affect it. */ fnArgs[0] = aggTransType; fnArgs[1] = aggTransType; - combinefn = lookup_agg_function(aggcombinefnName, 2, fnArgs, - variadicArgType, &combineType); + combinefn = lookup_agg_function(aggcombinefnName, 2, + fnArgs, InvalidOid, + &combineType); - /* Ensure the return type matches the aggregates trans type */ + /* Ensure the return type matches the aggregate's trans type */ if (combineType != aggTransType) ereport(ERROR, (errcode(ERRCODE_DATATYPE_MISMATCH), @@ -429,14 +430,14 @@ AggregateCreate(const char *aggName, /* * A combine function to combine INTERNAL states must accept nulls and - * ensure that the returned state is in the correct memory context. + * ensure that the returned state is in the correct memory context. We + * cannot directly check the latter, but we can check the former. */ if (aggTransType == INTERNALOID && func_strict(combinefn)) ereport(ERROR, (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), errmsg("combine function with transition type %s must not be declared STRICT", format_type_be(aggTransType)))); - } /* @@ -444,10 +445,11 @@ AggregateCreate(const char *aggName, */ if (aggserialfnName) { + /* signature is always serialize(internal) returns bytea */ fnArgs[0] = INTERNALOID; serialfn = lookup_agg_function(aggserialfnName, 1, - fnArgs, variadicArgType, + fnArgs, InvalidOid, &rettype); if (rettype != BYTEAOID) @@ -463,11 +465,12 @@ AggregateCreate(const char *aggName, */ if (aggdeserialfnName) { + /* signature is always deserialize(bytea, internal) returns internal */ fnArgs[0] = BYTEAOID; fnArgs[1] = INTERNALOID; /* dummy argument for type safety */ deserialfn = lookup_agg_function(aggdeserialfnName, 2, - fnArgs, variadicArgType, + fnArgs, InvalidOid, &rettype); if (rettype != INTERNALOID) @@ -770,7 +773,11 @@ AggregateCreate(const char *aggName, /* * lookup_agg_function - * common code for finding transfn, invtransfn, finalfn, and combinefn + * common code for finding aggregate support functions + * + * fnName: possibly-schema-qualified function name + * nargs, input_types: expected function argument types + * variadicArgType: type of variadic argument if any, else InvalidOid * * Returns OID of function, and stores its return type into *rettype * diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index 7624a3ac6e..a27c292d4c 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -2940,8 +2940,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) ereport(ERROR, (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("combine function for aggregate %u must be declared as STRICT", - aggref->aggfnoid))); + errmsg("combine function with transition type %s must not be declared STRICT", + format_type_be(aggtranstype)))); } else {