/*
 * Listing header: postgresql/src/backend/rewrite/rewriteSearchCycle.c
 * (682 lines, 21 KiB, C)
 */

/*-------------------------------------------------------------------------
*
* rewriteSearchCycle.c
* Support for rewriting SEARCH and CYCLE clauses.
*
* Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group
* Portions Copyright (c) 1994, Regents of the University of California
*
* IDENTIFICATION
* src/backend/rewrite/rewriteSearchCycle.c
*
*-------------------------------------------------------------------------
*/
#include "postgres.h"
#include "catalog/pg_operator_d.h"
#include "catalog/pg_type_d.h"
#include "nodes/makefuncs.h"
#include "nodes/parsenodes.h"
#include "nodes/pg_list.h"
#include "nodes/primnodes.h"
#include "parser/analyze.h"
#include "parser/parsetree.h"
#include "rewrite/rewriteManip.h"
#include "rewrite/rewriteSearchCycle.h"
#include "utils/fmgroids.h"
/*----------
* Rewrite a CTE with SEARCH or CYCLE clause
*
* Consider a CTE like
*
* WITH RECURSIVE ctename (col1, col2, col3) AS (
* query1
* UNION [ALL]
* SELECT trosl FROM ctename
* )
*
* With a search clause
*
* SEARCH BREADTH FIRST BY col1, col2 SET sqc
*
* the CTE is rewritten to
*
* WITH RECURSIVE ctename (col1, col2, col3, sqc) AS (
* SELECT col1, col2, col3, -- original WITH column list
* ROW(0, col1, col2) -- initial row of search columns
* FROM (query1) "*TLOCRN*" (col1, col2, col3)
* UNION [ALL]
* SELECT col1, col2, col3, -- same as above
* ROW(sqc.depth + 1, col1, col2) -- count depth
* FROM (SELECT trosl, ctename.sqc FROM ctename) "*TROCRN*" (col1, col2, col3, sqc)
* )
*
* (This isn't quite legal SQL: sqc.depth is meant to refer to the first
* column of sqc, which has a row type, but the field names are not defined
* here. Representing this properly in SQL would be more complicated (and the
* SQL standard actually does it in that more complicated way), but the
* internal representation allows us to construct it this way.)
*
* With a search clause
*
* SEARCH DEPTH FIRST BY col1, col2 SET sqc
*
* the CTE is rewritten to
*
* WITH RECURSIVE ctename (col1, col2, col3, sqc) AS (
* SELECT col1, col2, col3, -- original WITH column list
* ARRAY[ROW(col1, col2)] -- initial row of search columns
* FROM (query1) "*TLOCRN*" (col1, col2, col3)
* UNION [ALL]
* SELECT col1, col2, col3, -- same as above
* sqc || ARRAY[ROW(col1, col2)] -- record rows seen
* FROM (SELECT trosl, ctename.sqc FROM ctename) "*TROCRN*" (col1, col2, col3, sqc)
* )
*
* With a cycle clause
*
* CYCLE col1, col2 SET cmc TO 'Y' DEFAULT 'N' USING cpa
*
* (cmc = cycle mark column, cpa = cycle path) the CTE is rewritten to
*
* WITH RECURSIVE ctename (col1, col2, col3, cmc, cpa) AS (
* SELECT col1, col2, col3, -- original WITH column list
* 'N', -- cycle mark default
* ARRAY[ROW(col1, col2)] -- initial row of cycle columns
* FROM (query1) "*TLOCRN*" (col1, col2, col3)
* UNION [ALL]
* SELECT col1, col2, col3, -- same as above
* CASE WHEN ROW(col1, col2) = ANY (cpa) THEN 'Y' ELSE 'N' END, -- compute cycle mark column
* cpa || ARRAY[ROW(col1, col2)] -- record rows seen
* FROM (SELECT trosl, ctename.cmc, ctename.cpa FROM ctename) "*TROCRN*" (col1, col2, col3, cmc, cpa)
* WHERE cmc <> 'Y'
* )
*
* The expression to compute the cycle mark column in the right-hand query is
* written as
*
* CASE WHEN ROW(col1, col2) IN (SELECT p.* FROM TABLE(cpa) p) THEN cmv ELSE cmd END
*
* in the SQL standard, but in PostgreSQL we can use the scalar-array operator
* expression shown above.
*
* Also, in some of the cases where operators are shown above we actually
* directly produce the underlying function call.
*
* If both a search clause and a cycle clause are specified, then the search
* clause column is added before the cycle clause columns.
*/
/*
* Make a RowExpr from the specified column names, which have to be among the
* output columns of the CTE.
*/
static RowExpr *
make_path_rowexpr(const CommonTableExpr *cte, const List *col_list)
{
RowExpr *rowexpr;
ListCell *lc;
rowexpr = makeNode(RowExpr);
rowexpr->row_typeid = RECORDOID;
rowexpr->row_format = COERCE_IMPLICIT_CAST;
rowexpr->location = -1;
foreach(lc, col_list)
{
char *colname = strVal(lfirst(lc));
for (int i = 0; i < list_length(cte->ctecolnames); i++)
{
char *colname2 = strVal(list_nth(cte->ctecolnames, i));
if (strcmp(colname, colname2) == 0)
{
Var *var;
var = makeVar(1, i + 1,
list_nth_oid(cte->ctecoltypes, i),
list_nth_int(cte->ctecoltypmods, i),
list_nth_oid(cte->ctecolcollations, i),
0);
rowexpr->args = lappend(rowexpr->args, var);
rowexpr->colnames = lappend(rowexpr->colnames, makeString(colname));
break;
}
}
}
return rowexpr;
}
/*
 * Build ARRAY[rowexpr]: a single-element record array used as the starting
 * value of a depth-first search column or of a cycle path column.
 *
 * The given RowExpr is linked into the array's element list as-is (it is not
 * copied).
 */
static Expr *
make_path_initial_array(RowExpr *rowexpr)
{
	ArrayExpr  *seed = makeNode(ArrayExpr);

	/* One element, of generic record type. */
	seed->elements = list_make1(rowexpr);
	seed->element_typeid = RECORDOID;
	seed->array_typeid = RECORDARRAYOID;
	seed->location = -1;

	return (Expr *) seed;
}
/*
 * Build the expression that extends a path array by one row:
 *
 *     cpa || ARRAY[ROW(cols)]
 *
 * The left operand is a Var (varno 1) of type record[] whose varattno is
 * path_varattno; the right operand is a one-element record array wrapping
 * rowexpr.  Rather than an operator node, the result is a direct call of the
 * array_cat function (F_ARRAY_CAT).
 */
static Expr *
make_path_cat_expr(RowExpr *rowexpr, AttrNumber path_varattno)
{
	ArrayExpr  *new_elem = makeNode(ArrayExpr);
	FuncExpr   *cat_call;

	/* ARRAY[ROW(cols)] */
	new_elem->elements = list_make1(rowexpr);
	new_elem->element_typeid = RECORDOID;
	new_elem->array_typeid = RECORDARRAYOID;
	new_elem->location = -1;

	/* array_cat(cpa, ARRAY[ROW(cols)]) */
	cat_call = makeFuncExpr(F_ARRAY_CAT,
							RECORDARRAYOID,
							list_make2(makeVar(1, path_varattno,
											   RECORDARRAYOID, -1, 0, 0),
									   new_elem),
							InvalidOid,
							InvalidOid,
							COERCE_EXPLICIT_CALL);

	return (Expr *) cat_call;
}
/*
 * rewriteSearchAndCycle
 *
 * Rewrite a recursive CTE carrying a SEARCH and/or CYCLE clause into the
 * expanded form described at the top of this file.
 *
 * The function works on a copy of the given CTE (made with copyObject) and
 * returns that modified copy.  Both arms of the CTE's top-level UNION are
 * wrapped in new SELECTs that compute the extra columns, and the extra
 * columns are appended to the SetOperationStmt, to the CTE query's target
 * list, and to the CTE's column name/type/typmod/collation lists.
 *
 * At least one of cte->search_clause and cte->cycle_clause must be set
 * (asserted).  An ERROR with ERRCODE_FEATURE_NOT_SUPPORTED is raised when
 * the recursive reference to the CTE is not found in the top-level range
 * table of the right-hand UNION subquery.
 */
CommonTableExpr *
rewriteSearchAndCycle(CommonTableExpr *cte)
{
Query *ctequery;
SetOperationStmt *sos;
int rti1,
rti2;
RangeTblEntry *rte1,
*rte2,
*newrte;
Query *newq1,
*newq2;
Query *newsubquery;
RangeTblRef *rtr;
Oid search_seq_type = InvalidOid;
AttrNumber sqc_attno = InvalidAttrNumber;
AttrNumber cmc_attno = InvalidAttrNumber;
AttrNumber cpa_attno = InvalidAttrNumber;
TargetEntry *tle;
RowExpr *cycle_col_rowexpr = NULL;
RowExpr *search_col_rowexpr = NULL;
List *ewcl;
int cte_rtindex = -1;
Assert(cte->search_clause || cte->cycle_clause);
/* All modifications below are applied to a copy, which is what we return. */
cte = copyObject(cte);
ctequery = castNode(Query, cte->ctequery);
/*
* The top level of the CTE's query should be a UNION. Find the two
* subqueries.
*/
Assert(ctequery->setOperations);
sos = castNode(SetOperationStmt, ctequery->setOperations);
Assert(sos->op == SETOP_UNION);
rti1 = castNode(RangeTblRef, sos->larg)->rtindex;
rti2 = castNode(RangeTblRef, sos->rarg)->rtindex;
rte1 = rt_fetch(rti1, ctequery->rtable);
rte2 = rt_fetch(rti2, ctequery->rtable);
Assert(rte1->rtekind == RTE_SUBQUERY);
Assert(rte2->rtekind == RTE_SUBQUERY);
/*
* We'll need this a few times later.
*/
/*
 * The search sequence column is a single record for breadth-first and an
 * array of records for depth-first.
 */
if (cte->search_clause)
{
if (cte->search_clause->search_breadth_first)
search_seq_type = RECORDOID;
else
search_seq_type = RECORDARRAYOID;
}
/*
* Attribute numbers of the added columns in the CTE's column list
*/
if (cte->search_clause)
sqc_attno = list_length(cte->ctecolnames) + 1;
if (cte->cycle_clause)
{
cmc_attno = list_length(cte->ctecolnames) + 1;
cpa_attno = list_length(cte->ctecolnames) + 2;
/* The search column, if any, precedes the cycle columns: shift by one. */
if (cte->search_clause)
{
cmc_attno++;
cpa_attno++;
}
}
/*
* Make new left subquery
*/
newq1 = makeNode(Query);
newq1->commandType = CMD_SELECT;
newq1->canSetTag = true;
newrte = makeNode(RangeTblEntry);
newrte->rtekind = RTE_SUBQUERY;
newrte->alias = makeAlias("*TLOCRN*", cte->ctecolnames);
/* alias and eref point at the same Alias node here */
newrte->eref = newrte->alias;
/*
 * The original left query becomes the sole range table entry of newq1, one
 * subquery level further down, so its upper-level references are adjusted
 * with IncrementVarSublevelsUp (delta 1, starting at level 1).
 */
newsubquery = copyObject(rte1->subquery);
IncrementVarSublevelsUp((Node *) newsubquery, 1, 1);
newrte->subquery = newsubquery;
newrte->inFromCl = true;
newq1->rtable = list_make1(newrte);
rtr = makeNode(RangeTblRef);
rtr->rtindex = 1;
newq1->jointree = makeFromExpr(list_make1(rtr), NULL);
/*
* Make target list
*/
/*
 * First pass through the original columns as Vars of range table entry 1,
 * carrying over resorigtbl/resorigcol from the original left query's target
 * list (rte1->subquery still points at the original query at this point).
 */
for (int i = 0; i < list_length(cte->ctecolnames); i++)
{
Var *var;
var = makeVar(1, i + 1,
list_nth_oid(cte->ctecoltypes, i),
list_nth_int(cte->ctecoltypmods, i),
list_nth_oid(cte->ctecolcollations, i),
0);
tle = makeTargetEntry((Expr *) var, i + 1, strVal(list_nth(cte->ctecolnames, i)), false);
tle->resorigtbl = list_nth_node(TargetEntry, rte1->subquery->targetList, i)->resorigtbl;
tle->resorigcol = list_nth_node(TargetEntry, rte1->subquery->targetList, i)->resorigcol;
newq1->targetList = lappend(newq1->targetList, tle);
}
/* Initial value of the search sequence column */
if (cte->search_clause)
{
Expr *texpr;
search_col_rowexpr = make_path_rowexpr(cte, cte->search_clause->search_col_list);
if (cte->search_clause->search_breadth_first)
{
/*
 * Breadth first: ROW(0, cols).  An int8 constant 0 is prepended as the
 * first field, with the field name "*DEPTH*".
 */
search_col_rowexpr->args = lcons(makeConst(INT8OID, -1, InvalidOid, sizeof(int64),
Int64GetDatum(0), false, FLOAT8PASSBYVAL),
search_col_rowexpr->args);
search_col_rowexpr->colnames = lcons(makeString("*DEPTH*"), search_col_rowexpr->colnames);
texpr = (Expr *) search_col_rowexpr;
}
else
/* Depth first: ARRAY[ROW(cols)] */
texpr = make_path_initial_array(search_col_rowexpr);
tle = makeTargetEntry(texpr,
list_length(newq1->targetList) + 1,
cte->search_clause->search_seq_column,
false);
newq1->targetList = lappend(newq1->targetList, tle);
}
/* Initial values of the cycle mark and cycle path columns */
if (cte->cycle_clause)
{
/* cycle mark column starts out as the clause's default expression */
tle = makeTargetEntry((Expr *) cte->cycle_clause->cycle_mark_default,
list_length(newq1->targetList) + 1,
cte->cycle_clause->cycle_mark_column,
false);
newq1->targetList = lappend(newq1->targetList, tle);
/*
 * cycle path column starts out as ARRAY[ROW(cols)].  Note that
 * cycle_col_rowexpr is kept and used again (not copied) when building the
 * right subquery's target list below.
 */
cycle_col_rowexpr = make_path_rowexpr(cte, cte->cycle_clause->cycle_col_list);
tle = makeTargetEntry(make_path_initial_array(cycle_col_rowexpr),
list_length(newq1->targetList) + 1,
cte->cycle_clause->cycle_path_column,
false);
newq1->targetList = lappend(newq1->targetList, tle);
}
/*
 * Install the new left query, now that the original one's target list is no
 * longer needed, and extend the RTE's column names to match.
 */
rte1->subquery = newq1;
if (cte->search_clause)
{
rte1->eref->colnames = lappend(rte1->eref->colnames, makeString(cte->search_clause->search_seq_column));
}
if (cte->cycle_clause)
{
rte1->eref->colnames = lappend(rte1->eref->colnames, makeString(cte->cycle_clause->cycle_mark_column));
rte1->eref->colnames = lappend(rte1->eref->colnames, makeString(cte->cycle_clause->cycle_path_column));
}
/*
* Make new right subquery
*/
newq2 = makeNode(Query);
newq2->commandType = CMD_SELECT;
newq2->canSetTag = true;
newrte = makeNode(RangeTblEntry);
newrte->rtekind = RTE_SUBQUERY;
/*
 * ewcl: the original column names (copied) followed by the names of the
 * added columns.  It serves as the column alias list of the new subquery RTE
 * here and replaces cte->ctecolnames at the end of this function.
 */
ewcl = copyObject(cte->ctecolnames);
if (cte->search_clause)
{
ewcl = lappend(ewcl, makeString(cte->search_clause->search_seq_column));
}
if (cte->cycle_clause)
{
ewcl = lappend(ewcl, makeString(cte->cycle_clause->cycle_mark_column));
ewcl = lappend(ewcl, makeString(cte->cycle_clause->cycle_path_column));
}
newrte->alias = makeAlias("*TROCRN*", ewcl);
newrte->eref = newrte->alias;
/*
* Find the reference to the recursive CTE in the right UNION subquery's
* range table. We expect it to be two levels up from the UNION subquery
* (and must check that to avoid being fooled by sub-WITHs with the same
* CTE name). There will not be more than one such reference, because the
* parser would have rejected that (see checkWellFormedRecursion() in
* parse_cte.c). However, the parser doesn't insist that the reference
* appear in the UNION subquery's topmost range table, so we might fail to
* find it at all. That's an unimplemented case for the moment.
*/
/* Range table indexes are 1-based; cte_rtindex stays -1 if nothing matches. */
for (int rti = 1; rti <= list_length(rte2->subquery->rtable); rti++)
{
RangeTblEntry *e = rt_fetch(rti, rte2->subquery->rtable);
if (e->rtekind == RTE_CTE &&
strcmp(cte->ctename, e->ctename) == 0 &&
e->ctelevelsup == 2)
{
cte_rtindex = rti;
break;
}
}
if (cte_rtindex <= 0)
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("with a SEARCH or CYCLE clause, the recursive reference to WITH query \"%s\" must be at the top level of its right-hand SELECT",
cte->ctename)));
/* As for the left side: copy the original query and push it one level down. */
newsubquery = copyObject(rte2->subquery);
IncrementVarSublevelsUp((Node *) newsubquery, 1, 1);
/*
* Add extra columns to target list of subquery of right subquery
*/
/*
 * These Vars reference the recursive CTE's range table entry inside the
 * copied original right query (cte_rtindex), at the attribute numbers the
 * added columns have in the CTE's extended column list.
 */
if (cte->search_clause)
{
Var *var;
/* ctename.sqc */
var = makeVar(cte_rtindex, sqc_attno,
search_seq_type, -1, InvalidOid, 0);
tle = makeTargetEntry((Expr *) var,
list_length(newsubquery->targetList) + 1,
cte->search_clause->search_seq_column,
false);
newsubquery->targetList = lappend(newsubquery->targetList, tle);
}
if (cte->cycle_clause)
{
Var *var;
/* ctename.cmc */
var = makeVar(cte_rtindex, cmc_attno,
cte->cycle_clause->cycle_mark_type,
cte->cycle_clause->cycle_mark_typmod,
cte->cycle_clause->cycle_mark_collation, 0);
tle = makeTargetEntry((Expr *) var,
list_length(newsubquery->targetList) + 1,
cte->cycle_clause->cycle_mark_column,
false);
newsubquery->targetList = lappend(newsubquery->targetList, tle);
/* ctename.cpa */
var = makeVar(cte_rtindex, cpa_attno,
RECORDARRAYOID, -1, InvalidOid, 0);
tle = makeTargetEntry((Expr *) var,
list_length(newsubquery->targetList) + 1,
cte->cycle_clause->cycle_path_column,
false);
newsubquery->targetList = lappend(newsubquery->targetList, tle);
}
newrte->subquery = newsubquery;
newrte->inFromCl = true;
newq2->rtable = list_make1(newrte);
rtr = makeNode(RangeTblRef);
rtr->rtindex = 1;
if (cte->cycle_clause)
{
Expr *expr;
/*
* Add cmc <> cmv condition
*/
/*
 * The comparison becomes the WHERE clause of the new right query (see the
 * "WHERE cmc <> 'Y'" line in the example at the top of the file).
 * NOTE(review): cycle_mark_neop is presumably the "not equal" operator for
 * the cycle mark type; it is set up outside this file.
 */
expr = make_opclause(cte->cycle_clause->cycle_mark_neop, BOOLOID, false,
(Expr *) makeVar(1, cmc_attno,
cte->cycle_clause->cycle_mark_type,
cte->cycle_clause->cycle_mark_typmod,
cte->cycle_clause->cycle_mark_collation, 0),
(Expr *) cte->cycle_clause->cycle_mark_value,
InvalidOid,
cte->cycle_clause->cycle_mark_collation);
newq2->jointree = makeFromExpr(list_make1(rtr), (Node *) expr);
}
else
newq2->jointree = makeFromExpr(list_make1(rtr), NULL);
/*
* Make target list
*/
/*
 * Same pass-through of the original columns as on the left side, taking
 * resorigtbl/resorigcol from the original right query's target list
 * (rte2->subquery is only replaced further below).
 */
for (int i = 0; i < list_length(cte->ctecolnames); i++)
{
Var *var;
var = makeVar(1, i + 1,
list_nth_oid(cte->ctecoltypes, i),
list_nth_int(cte->ctecoltypmods, i),
list_nth_oid(cte->ctecolcollations, i),
0);
tle = makeTargetEntry((Expr *) var, i + 1, strVal(list_nth(cte->ctecolnames, i)), false);
tle->resorigtbl = list_nth_node(TargetEntry, rte2->subquery->targetList, i)->resorigtbl;
tle->resorigcol = list_nth_node(TargetEntry, rte2->subquery->targetList, i)->resorigcol;
newq2->targetList = lappend(newq2->targetList, tle);
}
/* Next value of the search sequence column */
if (cte->search_clause)
{
Expr *texpr;
if (cte->search_clause->search_breadth_first)
{
FieldSelect *fs;
FuncExpr *fexpr;
/*
* ROW(sqc.depth + 1, cols)
*/
/*
 * Work on a copy, because the first argument is overwritten below and the
 * original RowExpr is already part of the left query's target list.
 */
search_col_rowexpr = copyObject(search_col_rowexpr);
/* sqc.depth: field 1 of the incoming search sequence record, an int8 */
fs = makeNode(FieldSelect);
fs->arg = (Expr *) makeVar(1, sqc_attno, RECORDOID, -1, 0, 0);
fs->fieldnum = 1;
fs->resulttype = INT8OID;
fs->resulttypmod = -1;
/* "+ 1" is expressed as a direct call of int8inc (F_INT8INC) */
fexpr = makeFuncExpr(F_INT8INC, INT8OID, list_make1(fs), InvalidOid, InvalidOid, COERCE_EXPLICIT_CALL);
/* replace the constant 0 that was put first for the left query */
linitial(search_col_rowexpr->args) = fexpr;
texpr = (Expr *) search_col_rowexpr;
}
else
{
/*
* sqc || ARRAY[ROW(cols)]
*/
texpr = make_path_cat_expr(search_col_rowexpr, sqc_attno);
}
tle = makeTargetEntry(texpr,
list_length(newq2->targetList) + 1,
cte->search_clause->search_seq_column,
false);
newq2->targetList = lappend(newq2->targetList, tle);
}
/* Next values of the cycle mark and cycle path columns */
if (cte->cycle_clause)
{
ScalarArrayOpExpr *saoe;
CaseExpr *caseexpr;
CaseWhen *casewhen;
/*
* CASE WHEN ROW(cols) = ANY (cpa) THEN cmv ELSE cmd END
*/
/*
 * The array operand is the incoming cycle path column itself.  The same
 * cycle_col_rowexpr node built for the left query is referenced here and
 * again in the concatenation below; it is not copied.
 */
saoe = makeNode(ScalarArrayOpExpr);
saoe->location = -1;
saoe->opno = RECORD_EQ_OP;
saoe->useOr = true;
saoe->args = list_make2(cycle_col_rowexpr,
makeVar(1, cpa_attno, RECORDARRAYOID, -1, 0, 0));
caseexpr = makeNode(CaseExpr);
caseexpr->location = -1;
caseexpr->casetype = cte->cycle_clause->cycle_mark_type;
caseexpr->casecollid = cte->cycle_clause->cycle_mark_collation;
casewhen = makeNode(CaseWhen);
casewhen->location = -1;
casewhen->expr = (Expr *) saoe;
casewhen->result = (Expr *) cte->cycle_clause->cycle_mark_value;
caseexpr->args = list_make1(casewhen);
caseexpr->defresult = (Expr *) cte->cycle_clause->cycle_mark_default;
tle = makeTargetEntry((Expr *) caseexpr,
list_length(newq2->targetList) + 1,
cte->cycle_clause->cycle_mark_column,
false);
newq2->targetList = lappend(newq2->targetList, tle);
/*
* cpa || ARRAY[ROW(cols)]
*/
tle = makeTargetEntry(make_path_cat_expr(cycle_col_rowexpr, cpa_attno),
list_length(newq2->targetList) + 1,
cte->cycle_clause->cycle_path_column,
false);
newq2->targetList = lappend(newq2->targetList, tle);
}
/* Install the new right query and extend the RTE's column names to match. */
rte2->subquery = newq2;
if (cte->search_clause)
{
rte2->eref->colnames = lappend(rte2->eref->colnames, makeString(cte->search_clause->search_seq_column));
}
if (cte->cycle_clause)
{
rte2->eref->colnames = lappend(rte2->eref->colnames, makeString(cte->cycle_clause->cycle_mark_column));
rte2->eref->colnames = lappend(rte2->eref->colnames, makeString(cte->cycle_clause->cycle_path_column));
}
/*
* Add the additional columns to the SetOperationStmt
*/
/*
 * Each added column gets a type, typmod and collation entry.  When the set
 * operation is not UNION ALL (!sos->all), a matching entry is appended to
 * groupClauses as well.
 */
if (cte->search_clause)
{
sos->colTypes = lappend_oid(sos->colTypes, search_seq_type);
sos->colTypmods = lappend_int(sos->colTypmods, -1);
sos->colCollations = lappend_oid(sos->colCollations, InvalidOid);
if (!sos->all)
sos->groupClauses = lappend(sos->groupClauses,
makeSortGroupClauseForSetOp(search_seq_type, true));
}
if (cte->cycle_clause)
{
/* cycle mark column */
sos->colTypes = lappend_oid(sos->colTypes, cte->cycle_clause->cycle_mark_type);
sos->colTypmods = lappend_int(sos->colTypmods, cte->cycle_clause->cycle_mark_typmod);
sos->colCollations = lappend_oid(sos->colCollations, cte->cycle_clause->cycle_mark_collation);
if (!sos->all)
sos->groupClauses = lappend(sos->groupClauses,
makeSortGroupClauseForSetOp(cte->cycle_clause->cycle_mark_type, true));
/* cycle path column */
sos->colTypes = lappend_oid(sos->colTypes, RECORDARRAYOID);
sos->colTypmods = lappend_int(sos->colTypmods, -1);
sos->colCollations = lappend_oid(sos->colCollations, InvalidOid);
if (!sos->all)
sos->groupClauses = lappend(sos->groupClauses,
makeSortGroupClauseForSetOp(RECORDARRAYOID, true));
}
/*
* Add the additional columns to the CTE query's target list
*/
if (cte->search_clause)
{
ctequery->targetList = lappend(ctequery->targetList,
makeTargetEntry((Expr *) makeVar(1, sqc_attno,
search_seq_type, -1, InvalidOid, 0),
list_length(ctequery->targetList) + 1,
cte->search_clause->search_seq_column,
false));
}
if (cte->cycle_clause)
{
ctequery->targetList = lappend(ctequery->targetList,
makeTargetEntry((Expr *) makeVar(1, cmc_attno,
cte->cycle_clause->cycle_mark_type,
cte->cycle_clause->cycle_mark_typmod,
cte->cycle_clause->cycle_mark_collation, 0),
list_length(ctequery->targetList) + 1,
cte->cycle_clause->cycle_mark_column,
false));
ctequery->targetList = lappend(ctequery->targetList,
makeTargetEntry((Expr *) makeVar(1, cpa_attno,
RECORDARRAYOID, -1, InvalidOid, 0),
list_length(ctequery->targetList) + 1,
cte->cycle_clause->cycle_path_column,
false));
}
/*
* Add the additional columns to the CTE's output columns
*/
/*
 * This has to come last: the code above relies on cte->ctecolnames and the
 * parallel type/typmod/collation lists still describing only the original
 * columns.
 */
cte->ctecolnames = ewcl;
if (cte->search_clause)
{
cte->ctecoltypes = lappend_oid(cte->ctecoltypes, search_seq_type);
cte->ctecoltypmods = lappend_int(cte->ctecoltypmods, -1);
cte->ctecolcollations = lappend_oid(cte->ctecolcollations, InvalidOid);
}
if (cte->cycle_clause)
{
cte->ctecoltypes = lappend_oid(cte->ctecoltypes, cte->cycle_clause->cycle_mark_type);
cte->ctecoltypmods = lappend_int(cte->ctecoltypmods, cte->cycle_clause->cycle_mark_typmod);
cte->ctecolcollations = lappend_oid(cte->ctecolcollations, cte->cycle_clause->cycle_mark_collation);
cte->ctecoltypes = lappend_oid(cte->ctecoltypes, RECORDARRAYOID);
cte->ctecoltypmods = lappend_int(cte->ctecoltypmods, -1);
cte->ctecolcollations = lappend_oid(cte->ctecolcollations, InvalidOid);
}
return cte;
}