From 29f45e299e7ffa1df0db44b8452228625479487f Mon Sep 17 00:00:00 2001 From: David Rowley Date: Wed, 7 Jul 2021 16:29:17 +1200 Subject: [PATCH] Use a hash table to speed up NOT IN(values) Similar to 50e17ad28, which allowed hash tables to be used for IN clauses with a set of constants, here we add the same feature for NOT IN clauses. NOT IN evaluates the same as: WHERE a <> v1 AND a <> v2 AND a <> v3. Obviously, if we're using a hash table we must be exactly equivalent to that and return the same result taking into account that either side of the condition could contain a NULL. This requires a little bit of special handling to make it work with the hash table version. When processing NOT IN, the ScalarArrayOpExpr's operator will be the <> operator. To be able to build and look up a hash table we must use the <>'s negator operator. The planner checks if that exists and is hashable and sets the relevant fields in ScalarArrayOpExpr to instruct the executor to use hashing. Author: David Rowley, James Coleman Reviewed-by: James Coleman, Zhihong Yu Discussion: https://postgr.es/m/CAApHDvoF1mum_FRk6D621edcB6KSHBi2+GAgWmioj5AhOu2vwQ@mail.gmail.com --- src/backend/executor/execExpr.c | 24 +++++-- src/backend/executor/execExprInterp.c | 17 ++++- src/backend/nodes/copyfuncs.c | 1 + src/backend/nodes/equalfuncs.c | 6 ++ src/backend/nodes/outfuncs.c | 1 + src/backend/nodes/readfuncs.c | 1 + src/backend/optimizer/plan/setrefs.c | 3 + src/backend/optimizer/prep/prepqual.c | 1 + src/backend/optimizer/util/clauses.c | 80 ++++++++++++++++----- src/backend/parser/parse_oper.c | 1 + src/backend/partitioning/partbounds.c | 1 + src/include/catalog/catversion.h | 2 +- src/include/executor/execExpr.h | 1 + src/include/nodes/primnodes.h | 18 +++-- src/test/regress/expected/expressions.out | 84 +++++++++++++++++++++++ src/test/regress/sql/expressions.sql | 30 ++++++++ 16 files changed, 242 insertions(+), 29 deletions(-) diff --git a/src/backend/executor/execExpr.c 
b/src/backend/executor/execExpr.c index 8c9f8a6aeb..2c8c414a14 100644 --- a/src/backend/executor/execExpr.c +++ b/src/backend/executor/execExpr.c @@ -1205,19 +1205,34 @@ ExecInitExprRec(Expr *node, ExprState *state, AclResult aclresult; FmgrInfo *hash_finfo; FunctionCallInfo hash_fcinfo; + Oid cmpfuncid; + + /* + * Select the correct comparison function. When we do hashed + * NOT IN clauses, the opfuncid will be the inequality + * comparison function and negfuncid will be set to equality. + * We need to use the equality function for hash probes. + */ + if (OidIsValid(opexpr->negfuncid)) + { + Assert(OidIsValid(opexpr->hashfuncid)); + cmpfuncid = opexpr->negfuncid; + } + else + cmpfuncid = opexpr->opfuncid; Assert(list_length(opexpr->args) == 2); scalararg = (Expr *) linitial(opexpr->args); arrayarg = (Expr *) lsecond(opexpr->args); /* Check permission to call function */ - aclresult = pg_proc_aclcheck(opexpr->opfuncid, + aclresult = pg_proc_aclcheck(cmpfuncid, GetUserId(), ACL_EXECUTE); if (aclresult != ACLCHECK_OK) aclcheck_error(aclresult, OBJECT_FUNCTION, - get_func_name(opexpr->opfuncid)); - InvokeFunctionExecuteHook(opexpr->opfuncid); + get_func_name(cmpfuncid)); + InvokeFunctionExecuteHook(cmpfuncid); if (OidIsValid(opexpr->hashfuncid)) { @@ -1233,7 +1248,7 @@ ExecInitExprRec(Expr *node, ExprState *state, /* Set up the primary fmgr lookup information */ finfo = palloc0(sizeof(FmgrInfo)); fcinfo = palloc0(SizeForFunctionCallInfo(2)); - fmgr_info(opexpr->opfuncid, finfo); + fmgr_info(cmpfuncid, finfo); fmgr_info_set_expr((Node *) node, finfo); InitFunctionCallInfoData(*fcinfo, finfo, 2, opexpr->inputcollid, NULL, NULL); @@ -1274,6 +1289,7 @@ ExecInitExprRec(Expr *node, ExprState *state, /* And perform the operation */ scratch.opcode = EEOP_HASHED_SCALARARRAYOP; + scratch.d.hashedscalararrayop.inclause = opexpr->useOr; scratch.d.hashedscalararrayop.finfo = finfo; scratch.d.hashedscalararrayop.fcinfo_data = fcinfo; scratch.d.hashedscalararrayop.fn_addr = 
finfo->fn_addr; diff --git a/src/backend/executor/execExprInterp.c b/src/backend/executor/execExprInterp.c index 5483dee650..eb49817cee 100644 --- a/src/backend/executor/execExprInterp.c +++ b/src/backend/executor/execExprInterp.c @@ -3493,6 +3493,7 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco { ScalarArrayOpExprHashTable *elements_tab = op->d.hashedscalararrayop.elements_tab; FunctionCallInfo fcinfo = op->d.hashedscalararrayop.fcinfo_data; + bool inclause = op->d.hashedscalararrayop.inclause; bool strictfunc = op->d.hashedscalararrayop.finfo->fn_strict; Datum scalar = fcinfo->args[0].value; bool scalar_isnull = fcinfo->args[0].isnull; @@ -3596,7 +3597,12 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco /* Check the hash to see if we have a match. */ hashfound = NULL != saophash_lookup(elements_tab->hashtab, scalar); - result = BoolGetDatum(hashfound); + /* the result depends on if the clause is an IN or NOT IN clause */ + if (inclause) + result = BoolGetDatum(hashfound); /* IN */ + else + result = BoolGetDatum(!hashfound); /* NOT IN */ + resultnull = false; /* @@ -3605,7 +3611,7 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco * hashtable, but instead marked if we found any when building the table * in has_nulls. */ - if (!DatumGetBool(result) && op->d.hashedscalararrayop.has_nulls) + if (!hashfound && op->d.hashedscalararrayop.has_nulls) { if (strictfunc) { @@ -3633,6 +3639,13 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco result = op->d.hashedscalararrayop.fn_addr(fcinfo); resultnull = fcinfo->isnull; + + /* + * Reverse the result for NOT IN clauses since the above function + * is the equality function and we need not-equals. 
+ */ + if (!inclause) + result = !result; } } diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c index bd87f23784..6fef067957 100644 --- a/src/backend/nodes/copyfuncs.c +++ b/src/backend/nodes/copyfuncs.c @@ -1718,6 +1718,7 @@ _copyScalarArrayOpExpr(const ScalarArrayOpExpr *from) COPY_SCALAR_FIELD(opno); COPY_SCALAR_FIELD(opfuncid); COPY_SCALAR_FIELD(hashfuncid); + COPY_SCALAR_FIELD(negfuncid); COPY_SCALAR_FIELD(useOr); COPY_SCALAR_FIELD(inputcollid); COPY_NODE_FIELD(args); diff --git a/src/backend/nodes/equalfuncs.c b/src/backend/nodes/equalfuncs.c index dba3e6b31e..b9cc7b199c 100644 --- a/src/backend/nodes/equalfuncs.c +++ b/src/backend/nodes/equalfuncs.c @@ -414,6 +414,12 @@ _equalScalarArrayOpExpr(const ScalarArrayOpExpr *a, const ScalarArrayOpExpr *b) b->hashfuncid != 0) return false; + /* Likewise for the negfuncid */ + if (a->negfuncid != b->negfuncid && + a->negfuncid != 0 && + b->negfuncid != 0) + return false; + COMPARE_SCALAR_FIELD(useOr); COMPARE_SCALAR_FIELD(inputcollid); COMPARE_NODE_FIELD(args); diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c index e32b92e299..e09e4f77fe 100644 --- a/src/backend/nodes/outfuncs.c +++ b/src/backend/nodes/outfuncs.c @@ -1311,6 +1311,7 @@ _outScalarArrayOpExpr(StringInfo str, const ScalarArrayOpExpr *node) WRITE_OID_FIELD(opno); WRITE_OID_FIELD(opfuncid); WRITE_OID_FIELD(hashfuncid); + WRITE_OID_FIELD(negfuncid); WRITE_BOOL_FIELD(useOr); WRITE_OID_FIELD(inputcollid); WRITE_NODE_FIELD(args); diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c index f0b34ecfac..3dec0a2508 100644 --- a/src/backend/nodes/readfuncs.c +++ b/src/backend/nodes/readfuncs.c @@ -832,6 +832,7 @@ _readScalarArrayOpExpr(void) READ_OID_FIELD(opno); READ_OID_FIELD(opfuncid); READ_OID_FIELD(hashfuncid); + READ_OID_FIELD(negfuncid); READ_BOOL_FIELD(useOr); READ_OID_FIELD(inputcollid); READ_NODE_FIELD(args); diff --git a/src/backend/optimizer/plan/setrefs.c 
b/src/backend/optimizer/plan/setrefs.c index 61ccfd300b..210c4b3b14 100644 --- a/src/backend/optimizer/plan/setrefs.c +++ b/src/backend/optimizer/plan/setrefs.c @@ -1687,6 +1687,9 @@ fix_expr_common(PlannerInfo *root, Node *node) if (!OidIsValid(saop->hashfuncid)) record_plan_function_dependency(root, saop->hashfuncid); + + if (OidIsValid(saop->negfuncid)) + record_plan_function_dependency(root, saop->negfuncid); } else if (IsA(node, Const)) { diff --git a/src/backend/optimizer/prep/prepqual.c b/src/backend/optimizer/prep/prepqual.c index 42c3e4dc04..8908a9652e 100644 --- a/src/backend/optimizer/prep/prepqual.c +++ b/src/backend/optimizer/prep/prepqual.c @@ -128,6 +128,7 @@ negate_clause(Node *node) newopexpr->opno = negator; newopexpr->opfuncid = InvalidOid; newopexpr->hashfuncid = InvalidOid; + newopexpr->negfuncid = InvalidOid; newopexpr->useOr = !saopexpr->useOr; newopexpr->inputcollid = saopexpr->inputcollid; newopexpr->args = saopexpr->args; diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 059fa70254..8506165d68 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -2140,27 +2140,71 @@ convert_saop_to_hashed_saop_walker(Node *node, void *context) Oid lefthashfunc; Oid righthashfunc; - if (saop->useOr && arrayarg && IsA(arrayarg, Const) && - !((Const *) arrayarg)->constisnull && - get_op_hash_functions(saop->opno, &lefthashfunc, &righthashfunc) && - lefthashfunc == righthashfunc) + if (arrayarg && IsA(arrayarg, Const) && + !((Const *) arrayarg)->constisnull) { - Datum arrdatum = ((Const *) arrayarg)->constvalue; - ArrayType *arr = (ArrayType *) DatumGetPointer(arrdatum); - int nitems; - - /* - * Only fill in the hash functions if the array looks large enough - * for it to be worth hashing instead of doing a linear search. - */ - nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr)); - - if (nitems >= MIN_ARRAY_SIZE_FOR_HASHED_SAOP) + if (saop->useOr) { - /* Looks good. 
Fill in the hash functions */ - saop->hashfuncid = lefthashfunc; + if (get_op_hash_functions(saop->opno, &lefthashfunc, &righthashfunc) && + lefthashfunc == righthashfunc) + { + Datum arrdatum = ((Const *) arrayarg)->constvalue; + ArrayType *arr = (ArrayType *) DatumGetPointer(arrdatum); + int nitems; + + /* + * Only fill in the hash functions if the array looks + * large enough for it to be worth hashing instead of + * doing a linear search. + */ + nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr)); + + if (nitems >= MIN_ARRAY_SIZE_FOR_HASHED_SAOP) + { + /* Looks good. Fill in the hash functions */ + saop->hashfuncid = lefthashfunc; + } + return true; + } + } + else /* !saop->useOr */ + { + Oid negator = get_negator(saop->opno); + + /* + * Check if this is a NOT IN using an operator whose negator + * is hashable. If so we can still build a hash table and + * just ensure the lookup items are not in the hash table. + */ + if (OidIsValid(negator) && + get_op_hash_functions(negator, &lefthashfunc, &righthashfunc) && + lefthashfunc == righthashfunc) + { + Datum arrdatum = ((Const *) arrayarg)->constvalue; + ArrayType *arr = (ArrayType *) DatumGetPointer(arrdatum); + int nitems; + + /* + * Only fill in the hash functions if the array looks + * large enough for it to be worth hashing instead of + * doing a linear search. + */ + nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr)); + + if (nitems >= MIN_ARRAY_SIZE_FOR_HASHED_SAOP) + { + /* Looks good. Fill in the hash functions */ + saop->hashfuncid = lefthashfunc; + + /* + * Also set the negfuncid. The executor will need + * that to perform hashtable lookups. 
+ */ + saop->negfuncid = get_opcode(negator); + } + return true; + } } - return true; } } diff --git a/src/backend/parser/parse_oper.c b/src/backend/parser/parse_oper.c index 4e46079990..bc34a23afc 100644 --- a/src/backend/parser/parse_oper.c +++ b/src/backend/parser/parse_oper.c @@ -895,6 +895,7 @@ make_scalar_array_op(ParseState *pstate, List *opname, result->opno = oprid(tup); result->opfuncid = opform->oprcode; result->hashfuncid = InvalidOid; + result->negfuncid = InvalidOid; result->useOr = useOr; /* inputcollid will be set by parse_collate.c */ result->args = args; diff --git a/src/backend/partitioning/partbounds.c b/src/backend/partitioning/partbounds.c index 00c394445a..38baaefcda 100644 --- a/src/backend/partitioning/partbounds.c +++ b/src/backend/partitioning/partbounds.c @@ -3878,6 +3878,7 @@ make_partition_op_expr(PartitionKey key, int keynum, saopexpr->opno = operoid; saopexpr->opfuncid = get_opcode(operoid); saopexpr->hashfuncid = InvalidOid; + saopexpr->negfuncid = InvalidOid; saopexpr->useOr = true; saopexpr->inputcollid = key->partcollation[keynum]; saopexpr->args = list_make2(arg1, arrexpr); diff --git a/src/include/catalog/catversion.h b/src/include/catalog/catversion.h index 1b23c7c253..e92ecaf344 100644 --- a/src/include/catalog/catversion.h +++ b/src/include/catalog/catversion.h @@ -53,6 +53,6 @@ */ /* yyyymmddN */ -#define CATALOG_VERSION_NO 202106151 +#define CATALOG_VERSION_NO 202107071 #endif diff --git a/src/include/executor/execExpr.h b/src/include/executor/execExpr.h index 785600d04d..6a24341faa 100644 --- a/src/include/executor/execExpr.h +++ b/src/include/executor/execExpr.h @@ -574,6 +574,7 @@ typedef struct ExprEvalStep struct { bool has_nulls; + bool inclause; /* true for IN and false for NOT IN */ struct ScalarArrayOpExprHashTable *elements_tab; FmgrInfo *finfo; /* function's lookup data */ FunctionCallInfo fcinfo_data; /* arguments etc */ diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h index 
9ae851d847..996c3e4016 100644 --- a/src/include/nodes/primnodes.h +++ b/src/include/nodes/primnodes.h @@ -580,10 +580,18 @@ typedef OpExpr NullIfExpr; * the result type (or the collation) because it must be boolean. * * A ScalarArrayOpExpr with a valid hashfuncid is evaluated during execution - * by building a hash table containing the Const values from the rhs arg. - * This table is probed during expression evaluation. Only useOr=true - * ScalarArrayOpExpr with Const arrays on the rhs can have the hashfuncid - * field set. See convert_saop_to_hashed_saop(). + * by building a hash table containing the Const values from the RHS arg. + * This table is probed during expression evaluation. The planner will set + * hashfuncid to the hash function which must be used to build and probe the + * hash table. The executor determines if it should use hash-based checks or + * the more traditional means based on if the hashfuncid is set or not. + * + * When performing hashed NOT IN, the negfuncid will also be set to the + * equality function which the hash table must use to build and probe the hash + * table. opno and opfuncid will remain set to the <> operator and its + * corresponding function and won't be used during execution. For + * non-hashtable based NOT INs, negfuncid will be set to InvalidOid. See + * convert_saop_to_hashed_saop(). */ typedef struct ScalarArrayOpExpr { @@ -591,6 +599,8 @@ typedef struct ScalarArrayOpExpr Oid opno; /* PG_OPERATOR OID of the operator */ Oid opfuncid; /* PG_PROC OID of comparison function */ Oid hashfuncid; /* PG_PROC OID of hash func or InvalidOid */ + Oid negfuncid; /* PG_PROC OID of negator of opfuncid function + * or InvalidOid. 
See above */ bool useOr; /* true for ANY, false for ALL */ Oid inputcollid; /* OID of collation that operator should use */ List *args; /* the scalar and array operands */ diff --git a/src/test/regress/expected/expressions.out b/src/test/regress/expected/expressions.out index 5944dfd5e1..84159cb21f 100644 --- a/src/test/regress/expected/expressions.out +++ b/src/test/regress/expected/expressions.out @@ -216,6 +216,55 @@ select return_text_input('a') in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', ' t (1 row) +-- NOT IN +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1); + ?column? +---------- + f +(1 row) + +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 0); + ?column? +---------- + t +(1 row) + +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 2, null); + ?column? +---------- + +(1 row) + +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1, null); + ?column? +---------- + f +(1 row) + +select return_int_input(1) not in (null, null, null, null, null, null, null, null, null, null, null); + ?column? +---------- + +(1 row) + +select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1); + ?column? +---------- + +(1 row) + +select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, null); + ?column? +---------- + +(1 row) + +select return_text_input('a') not in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'); + ?column? +---------- + f +(1 row) + rollback; -- Test with non-strict equality function. -- We need to create our own type for this. 
@@ -242,6 +291,11 @@ begin end if; end; $$ language plpgsql immutable; +create function myintne(myint, myint) returns bool as $$ +begin + return not myinteq($1, $2); +end; +$$ language plpgsql immutable; create operator = ( leftarg = myint, rightarg = myint, @@ -252,6 +306,16 @@ create operator = ( join = eqjoinsel, merges ); +create operator <> ( + leftarg = myint, + rightarg = myint, + commutator = <>, + negator = =, + procedure = myintne, + restrict = eqsel, + join = eqjoinsel, + merges +); create operator class myint_ops default for type myint using hash as operator 1 = (myint, myint), @@ -266,6 +330,16 @@ select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,6 (2 rows) +select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null); + a +--- +(0 rows) + +select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null); + a +--- +(0 rows) + -- ensure the result matched with the non-hashed version. We simply remove -- some array elements so that we don't reach the hashing threshold. 
select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, null); @@ -275,4 +349,14 @@ select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, (2 rows) +select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint, null); + a +--- +(0 rows) + +select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint, null); + a +--- +(0 rows) + rollback; diff --git a/src/test/regress/sql/expressions.sql b/src/test/regress/sql/expressions.sql index b3fd1b5ecb..bf30f41505 100644 --- a/src/test/regress/sql/expressions.sql +++ b/src/test/regress/sql/expressions.sql @@ -93,6 +93,15 @@ select return_int_input(1) in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1, null); select return_int_input(null::int) in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1); select return_int_input(null::int) in (10, 9, 2, 8, 3, 7, 4, 6, 5, null); select return_text_input('a') in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'); +-- NOT IN +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1); +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 0); +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 2, null); +select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1, null); +select return_int_input(1) not in (null, null, null, null, null, null, null, null, null, null, null); +select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1); +select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, null); +select return_text_input('a') not in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'); rollback; @@ -124,6 +133,12 @@ begin end; $$ language plpgsql immutable; +create function myintne(myint, myint) returns bool as $$ +begin + return not myinteq($1, $2); +end; +$$ language plpgsql immutable; + create operator = ( leftarg = myint, rightarg = myint, @@ -135,6 +150,17 @@ create operator = ( merges ); +create operator <> ( + leftarg = myint, + rightarg = myint, + commutator = <>, 
+ negator = =, + procedure = myintne, + restrict = eqsel, + join = eqjoinsel, + merges +); + create operator class myint_ops default for type myint using hash as operator 1 = (myint, myint), @@ -145,8 +171,12 @@ insert into inttest values(1::myint),(null); -- try an array with enough elements to cause hashing select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null); +select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null); +select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null); -- ensure the result matched with the non-hashed version. We simply remove -- some array elements so that we don't reach the hashing threshold. select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, null); +select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint, null); +select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint, null); rollback;