/*
 * Rewriting of tsquery values: replace (or remove) every occurrence of one
 * sub-query inside another query.
 */
#include "postgres.h"
#include "executor/spi.h"

#include "query_util.h"

/* Memory context of the running aggregate; set on every rewrite_accum() call. */
MemoryContext AggregateContext = NULL;

/*
 * Advance counters[0..last] to the next strictly increasing tuple of indexes
 * drawn from 0..total-1 (i.e. the next combination in lexicographic order).
 *
 * Returns 1 when a new tuple was produced, 0 when all tuples are exhausted.
 */
static int
addone(int *counters, int last, int total)
{
    counters[last]++;
    if (counters[last] >= total)
    {
        if (last == 0)
            return 0;
        /* carry into the previous position, which must stay below total-1 */
        if (addone(counters, last - 1, total - 1) == 0)
            return 0;
        counters[last] = counters[last - 1] + 1;
    }
    return 1;
}

/*
 * Try to match the pattern 'ex' against 'node' itself or, when 'node' is an
 * operator with more children than 'ex', against a subset of its children.
 *
 * On a match the matched part is replaced by a copy of 'subs' (allocated with
 * 'memtype' and flagged QTN_NOCHANGE), or dropped when 'subs' is NULL, and
 * *isfind is set to true.
 *
 * Returns the (possibly replaced) node; NULL when the whole node matched and
 * 'subs' is NULL.
 */
static QTNode *
findeq(QTNode *node, QTNode *ex, MemoryType memtype, QTNode *subs, bool *isfind)
{
    /* cheap rejection: signature bits, item type and item value must agree */
    if ((node->sign & ex->sign) != ex->sign ||
        node->valnode->type != ex->valnode->type ||
        node->valnode->val != ex->valnode->val)
        return node;

    /* nodes produced by an earlier substitution are left alone */
    if (node->flags & QTN_NOCHANGE)
        return node;

    if (node->valnode->type == OPR)
    {
        if (node->nchild == ex->nchild)
        {
            if (QTNEq(node, ex))
            {
                QTNFree(node);
                if (subs)
                {
                    node = QTNCopy(subs, memtype);
                    node->flags |= QTN_NOCHANGE;
                }
                else
                    node = NULL;
                *isfind = true;
            }
        }
        else if (node->nchild > ex->nchild)
        {
            /*
             * Build a temporary operator node with ex->nchild children and
             * try every combination of node's children in it.
             */
            int        *counters = (int *) palloc(sizeof(int) * node->nchild);
            int         i;
            QTNode     *tnode = (QTNode *) MEMALLOC(memtype, sizeof(QTNode));

            memset(tnode, 0, sizeof(QTNode));
            tnode->child = (QTNode **) MEMALLOC(memtype, sizeof(QTNode *) * ex->nchild);
            tnode->nchild = ex->nchild;
            tnode->valnode = (ITEM *) MEMALLOC(memtype, sizeof(ITEM));
            *(tnode->valnode) = *(ex->valnode);

            for (i = 0; i < ex->nchild; i++)
                counters[i] = i;

            do
            {
                tnode->sign = 0;
                for (i = 0; i < ex->nchild; i++)
                {
                    tnode->child[i] = node->child[counters[i]];
                    tnode->sign |= tnode->child[i]->sign;
                }

                if (QTNEq(tnode, ex))
                {
                    int         j = 0;

                    MEMFREE(memtype, tnode->valnode);
                    MEMFREE(memtype, tnode->child);
                    MEMFREE(memtype, tnode);
                    if (subs)
                    {
                        tnode = QTNCopy(subs, memtype);
                        tnode->flags = QTN_NOCHANGE | QTN_NEEDFREE;
                    }
                    else
                        tnode = NULL;

                    /* first matched slot receives the substitute ... */
                    node->child[counters[0]] = tnode;

                    /* ... the remaining matched slots are emptied ... */
                    for (i = 1; i < ex->nchild; i++)
                        node->child[counters[i]] = NULL;

                    /* ... and the child array is compacted */
                    for (i = 0; i < node->nchild; i++)
                    {
                        if (node->child[i])
                        {
                            node->child[j] = node->child[i];
                            j++;
                        }
                    }

                    node->nchild = j;

                    *isfind = true;

                    break;
                }
            } while (addone(counters, ex->nchild - 1, node->nchild));

            if (tnode && (tnode->flags & QTN_NOCHANGE) == 0)
            {
                /* no match: tnode is still the temporary node, release it */
                MEMFREE(memtype, tnode->valnode);
                MEMFREE(memtype, tnode->child);
                MEMFREE(memtype, tnode);
            }
            else
                QTNSort(node);
            pfree(counters);
        }
    }
    else if (QTNEq(node, ex))
    {
        QTNFree(node);
        if (subs)
        {
            node = QTNCopy(subs, memtype);
            node->flags |= QTN_NOCHANGE;
        }
        else
        {
            node = NULL;
        }
        *isfind = true;
    }

    return node;
}

/*
 * Apply findeq() to 'root' and then, recursively, to each of its children,
 * skipping subtrees flagged QTN_NOCHANGE.  Children that matched with a NULL
 * substitute are stored back as NULL entries.
 */
static QTNode *
dofindsubquery(QTNode *root, QTNode *ex, MemoryType memtype, QTNode *subs, bool *isfind)
{
    root = findeq(root, ex, memtype, subs, isfind);

    if (root && (root->flags & QTN_NOCHANGE) == 0 && root->valnode->type == OPR)
    {
        int         i;

        for (i = 0; i < root->nchild; i++)
            root->child[i] = dofindsubquery(root->child[i], ex, memtype, subs, isfind);
    }

    return root;
}

/*
 * Remove NULL entries from the child array of 'root' (top level only).
 * A '!' operator left without children is freed and NULL is returned; an
 * operator left with exactly one child is replaced by that child.
 *
 * NOTE(review): the single-child branch also fires for a '!' node that still
 * has its operand, and nested operators are not cleaned -- confirm whether
 * that is intended.
 */
static QTNode *
dropvoidsubtree(QTNode *root)
{
    if (!root)
        return NULL;

    if (root->valnode->type == OPR)
    {
        int         i,
                    j = 0;

        for (i = 0; i < root->nchild; i++)
        {
            if (root->child[i])
            {
                root->child[j] = root->child[i];
                j++;
            }
        }

        root->nchild = j;

        if (root->valnode->val == (int4) '!' && root->nchild == 0)
        {
            QTNFree(root);
            root = NULL;
        }
        else if (root->nchild == 1)
        {
            QTNode     *nroot = root->child[0];

            pfree(root);
            root = nroot;
        }
    }

    return root;
}

/*
 * Replace every occurrence of 'ex' in the tree 'root' by a copy of 'subs',
 * or remove it when 'subs' is NULL.
 *
 * If 'isfind' is not NULL it receives whether anything matched.
 * Returns the new root, which may be NULL.
 */
static QTNode *
findsubquery(QTNode *root, QTNode *ex, MemoryType memtype, QTNode *subs, bool *isfind)
{
    bool        DidFind = false;

    root = dofindsubquery(root, ex, memtype, subs, &DidFind);

    /* removals can leave holes in the child array; tidy up the root */
    if (!subs && DidFind)
        root = dropvoidsubtree(root);

    if (isfind)
        *isfind = DidFind;

    return root;
}

/* Cached OID of the tsquery type; looked up on first use. */
static Oid  tsqOid = InvalidOid;

/*
 * Look up the OID of the tsquery type in pg_type and cache it in tsqOid.
 * The caller must already be connected to SPI.
 */
static void
get_tsq_Oid(void)
{
    int         ret;
    bool        isnull;

    if ((ret = SPI_exec("select oid from pg_type where typname='tsquery'", 1)) < 0)
        /* internal error */
        elog(ERROR, "SPI_exec to get tsquery oid returns %d", ret);

    /* need at least one row before touching vals[0] */
    if (SPI_processed < 1)
        /* internal error */
        elog(ERROR, "There is no tsquery type");

    tsqOid = DatumGetObjectId(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
    if (tsqOid == InvalidOid)
        /* internal error */
        elog(ERROR, "tsquery type has InvalidOid");
}


PG_FUNCTION_INFO_V1(tsquery_rewrite);
PG_FUNCTION_INFO_V1(rewrite_accum);
Datum       rewrite_accum(PG_FUNCTION_ARGS);

/*
 * Aggregate transition function.
 *
 * Argument 0 is the accumulated query (may be NULL on the first call),
 * argument 1 is a one-dimensional array of exactly three tsquery values:
 * the query to rewrite (used only while the accumulator is still in its
 * initial state), the pattern to look for, and the substitute.
 *
 * An accumulator with size == 0 and len == HDRSIZEQT is the initial state;
 * size == 0 with a larger len marks a query that was rewritten to nothing.
 */
Datum
rewrite_accum(PG_FUNCTION_ARGS)
{
    QUERYTYPE  *acc = (QUERYTYPE *) PG_GETARG_POINTER(0);
    ArrayType  *qa;
    QUERYTYPE  *q;
    QTNode     *qex,
               *subs = NULL,
               *acctree;
    bool        isfind = false;

    AggregateContext = ((AggState *) fcinfo->context)->aggcontext;

    if (acc == NULL || PG_ARGISNULL(0))
    {
        acc = (QUERYTYPE *) MEMALLOC(AggMemory, sizeof(QUERYTYPE));
        acc->len = HDRSIZEQT;
        acc->size = 0;
    }

    /* test for NULL before detoasting: a NULL datum must not be dereferenced */
    if (PG_ARGISNULL(1) || PG_GETARG_POINTER(1) == NULL)
        PG_RETURN_POINTER(acc);

    qa = (ArrayType *) DatumGetPointer(PG_DETOAST_DATUM_COPY(PG_GETARG_DATUM(1)));

    if (ARR_NDIM(qa) != 1)
        elog(ERROR, "array must be one-dimensional, not %d dimension", ARR_NDIM(qa));

    if (ArrayGetNItems(ARR_NDIM(qa), ARR_DIMS(qa)) != 3)
        elog(ERROR, "array should have only three elements");

    if (tsqOid == InvalidOid)
    {
        SPI_connect();
        get_tsq_Oid();
        SPI_finish();
    }

    if (ARR_ELEMTYPE(qa) != tsqOid)
        elog(ERROR, "array should contain tsquery type");

    /* element 0: the query to rewrite */
    q = (QUERYTYPE *) ARR_DATA_PTR(qa);
    if (q->size == 0)
        PG_RETURN_POINTER(acc);

    if (!acc->size)
    {
        if (acc->len > HDRSIZEQT)
            PG_RETURN_POINTER(acc);     /* already rewritten to nothing */
        else
            acctree = QT2QTN(GETQUERY(q), GETOPERAND(q));
    }
    else
        acctree = QT2QTN(GETQUERY(acc), GETOPERAND(acc));

    QTNTernary(acctree);
    QTNSort(acctree);

    /* element 1: the pattern */
    q = (QUERYTYPE *) (((char *) ARR_DATA_PTR(qa)) + MAXALIGN(q->len));
    if (q->size == 0)
        PG_RETURN_POINTER(acc);
    qex = QT2QTN(GETQUERY(q), GETOPERAND(q));
    QTNTernary(qex);
    QTNSort(qex);

    /* element 2: the substitute; an empty one means "remove" */
    q = (QUERYTYPE *) (((char *) q) + MAXALIGN(q->len));
    if (q->size)
        subs = QT2QTN(GETQUERY(q), GETOPERAND(q));

    acctree = findsubquery(acctree, qex, PlainMemory, subs, &isfind);

    if (isfind || !acc->size)
    {
        /* pfree( acc ); do not pfree(p), because nodeAgg.c will */
        if (acctree)
        {
            QTNBinary(acctree);
            acc = QTN2QT(acctree, AggMemory);
        }
        else
        {
            acc = (QUERYTYPE *) MEMALLOC(AggMemory, HDRSIZEQT * 2);
            acc->len = HDRSIZEQT * 2;
            acc->size = 0;
        }
    }

    QTNFree(qex);
    QTNFree(subs);
    QTNFree(acctree);

    PG_RETURN_POINTER(acc);
}

PG_FUNCTION_INFO_V1(rewrite_finish);
Datum       rewrite_finish(PG_FUNCTION_ARGS);

/*
 * Aggregate final function: return a palloc'ed copy of the accumulated
 * query, or an empty query when nothing was accumulated.
 */
Datum
rewrite_finish(PG_FUNCTION_ARGS)
{
    QUERYTYPE  *acc = (QUERYTYPE *) PG_GETARG_POINTER(0);
    QUERYTYPE  *rewrited;

    if (acc == NULL || PG_ARGISNULL(0) || acc->size == 0)
    {
        acc = (QUERYTYPE *) palloc(sizeof(QUERYTYPE));
        acc->len = HDRSIZEQT;
        acc->size = 0;
    }

    rewrited = (QUERYTYPE *) palloc(acc->len);
    memcpy(rewrited, acc, acc->len);
    pfree(acc);

    PG_RETURN_POINTER(rewrited);
}

Datum       tsquery_rewrite(PG_FUNCTION_ARGS);

/*
 * tsquery_rewrite(query tsquery, sql text)
 *
 * Runs 'sql' through SPI; it must return two tsquery columns (pattern,
 * substitute).  Each row with both columns non-NULL is applied to 'query' in
 * turn.  Rows are fetched 100 at a time.  Returns the rewritten query, which
 * is empty when everything was removed.
 */
Datum
tsquery_rewrite(PG_FUNCTION_ARGS)
{
    QUERYTYPE  *query = (QUERYTYPE *) DatumGetPointer(PG_DETOAST_DATUM_COPY(PG_GETARG_DATUM(0)));
    text       *in = PG_GETARG_TEXT_P(1);
    QUERYTYPE  *rewrited = query;
    QTNode     *tree;
    char       *buf;
    void       *plan;
    Portal      portal;
    bool        isnull;
    int         i;

    if (query->size == 0)
    {
        PG_FREE_IF_COPY(in, 1);
        PG_RETURN_POINTER(rewrited);
    }

    tree = QT2QTN(GETQUERY(query), GETOPERAND(query));
    QTNTernary(tree);
    QTNSort(tree);

    /* NUL-terminated copy of the SQL text */
    buf = (char *) palloc(VARSIZE(in));
    memcpy(buf, VARDATA(in), VARSIZE(in) - VARHDRSZ);
    buf[VARSIZE(in) - VARHDRSZ] = '\0';

    SPI_connect();

    if (tsqOid == InvalidOid)
        get_tsq_Oid();

    if ((plan = SPI_prepare(buf, 0, NULL)) == NULL)
        elog(ERROR, "SPI_prepare('%s') returns NULL", buf);

    if ((portal = SPI_cursor_open(NULL, plan, NULL, NULL, false)) == NULL)
        elog(ERROR, "SPI_cursor_open('%s') returns NULL", buf);

    SPI_cursor_fetch(portal, true, 100);

    /* guard against a missing tuple table before inspecting its descriptor */
    if (SPI_tuptable == NULL || SPI_tuptable->tupdesc->natts != 2)
        elog(ERROR, "number of fields doesn't equal to 2");

    if (SPI_gettypeid(SPI_tuptable->tupdesc, 1) != tsqOid)
        elog(ERROR, "column #1 isn't of tsquery type");

    if (SPI_gettypeid(SPI_tuptable->tupdesc, 2) != tsqOid)
        elog(ERROR, "column #2 isn't of tsquery type");

    /* stop early once the whole query has been rewritten away */
    while (SPI_processed > 0 && tree)
    {
        for (i = 0; i < SPI_processed && tree; i++)
        {
            Datum       qdata = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);
            Datum       sdata;

            if (isnull)
                continue;

            sdata = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 2, &isnull);

            if (!isnull)
            {
                QUERYTYPE  *qtex = (QUERYTYPE *) DatumGetPointer(PG_DETOAST_DATUM(qdata));
                QUERYTYPE  *qtsubs = (QUERYTYPE *) DatumGetPointer(PG_DETOAST_DATUM(sdata));
                QTNode     *qex,
                           *qsubs = NULL;

                if (qtex->size == 0)
                {
                    /* empty pattern: free detoasted copies, if any, and skip */
                    if (qtex != (QUERYTYPE *) DatumGetPointer(qdata))
                        pfree(qtex);
                    if (qtsubs != (QUERYTYPE *) DatumGetPointer(sdata))
                        pfree(qtsubs);
                    continue;
                }

                qex = QT2QTN(GETQUERY(qtex), GETOPERAND(qtex));

                QTNTernary(qex);
                QTNSort(qex);

                if (qtsubs->size)
                    qsubs = QT2QTN(GETQUERY(qtsubs), GETOPERAND(qtsubs));

                tree = findsubquery(tree, qex, SPIMemory, qsubs, NULL);

                QTNFree(qex);
                if (qtex != (QUERYTYPE *) DatumGetPointer(qdata))
                    pfree(qtex);
                QTNFree(qsubs);
                if (qtsubs != (QUERYTYPE *) DatumGetPointer(sdata))
                    pfree(qtsubs);
            }
        }

        SPI_freetuptable(SPI_tuptable);
        SPI_cursor_fetch(portal, true, 100);
    }

    SPI_freetuptable(SPI_tuptable);
    SPI_cursor_close(portal);
    SPI_freeplan(plan);
    SPI_finish();

    if (tree)
    {
        QTNBinary(tree);
        rewrited = QTN2QT(tree, PlainMemory);
        QTNFree(tree);
        PG_FREE_IF_COPY(query, 0);
    }
    else
    {
        /* nothing left: turn the copied input into an empty query */
        rewrited->len = HDRSIZEQT;
        rewrited->size = 0;
    }

    pfree(buf);
    PG_FREE_IF_COPY(in, 1);
    PG_RETURN_POINTER(rewrited);
}


PG_FUNCTION_INFO_V1(tsquery_rewrite_query);
Datum       tsquery_rewrite_query(PG_FUNCTION_ARGS);

/*
 * tsquery_rewrite_query(query tsquery, ex tsquery, subst tsquery)
 *
 * Replace every occurrence of 'ex' in 'query' by 'subst'; an empty 'subst'
 * removes the occurrences.  An empty 'query' or 'ex' returns 'query'
 * unchanged.
 */
Datum
tsquery_rewrite_query(PG_FUNCTION_ARGS)
{
    QUERYTYPE  *query = (QUERYTYPE *) DatumGetPointer(PG_DETOAST_DATUM_COPY(PG_GETARG_DATUM(0)));
    QUERYTYPE  *ex = (QUERYTYPE *) DatumGetPointer(PG_DETOAST_DATUM(PG_GETARG_DATUM(1)));
    QUERYTYPE  *subst = (QUERYTYPE *) DatumGetPointer(PG_DETOAST_DATUM(PG_GETARG_DATUM(2)));
    QUERYTYPE  *rewrited = query;
    QTNode     *tree,
               *qex,
               *subs = NULL;

    if (query->size == 0 || ex->size == 0)
    {
        PG_FREE_IF_COPY(ex, 1);
        PG_FREE_IF_COPY(subst, 2);
        PG_RETURN_POINTER(rewrited);
    }

    tree = QT2QTN(GETQUERY(query), GETOPERAND(query));
    QTNTernary(tree);
    QTNSort(tree);

    qex = QT2QTN(GETQUERY(ex), GETOPERAND(ex));
    QTNTernary(qex);
    QTNSort(qex);

    if (subst->size)
        subs = QT2QTN(GETQUERY(subst), GETOPERAND(subst));

    tree = findsubquery(tree, qex, PlainMemory, subs, NULL);
    QTNFree(qex);
    QTNFree(subs);

    if (!tree)
    {
        /* everything was removed: return the copied input as an empty query */
        rewrited->len = HDRSIZEQT;
        rewrited->size = 0;
        PG_FREE_IF_COPY(ex, 1);
        PG_FREE_IF_COPY(subst, 2);
        PG_RETURN_POINTER(rewrited);
    }
    else
    {
        QTNBinary(tree);
        rewrited = QTN2QT(tree, PlainMemory);
        QTNFree(tree);
    }

    PG_FREE_IF_COPY(query, 0);
    PG_FREE_IF_COPY(ex, 1);
    PG_FREE_IF_COPY(subst, 2);
    PG_RETURN_POINTER(rewrited);
}