postgresql/contrib/tsearch2/ts_stat.c
Tom Lane 234a02b2a8 Replace direct assignments to VARATT_SIZEP(x) with SET_VARSIZE(x, len).
Get rid of VARATT_SIZE and VARATT_DATA, which were simply redundant with
VARSIZE and VARDATA, and as a consequence almost no code was using the
longer names.  Rename the length fields of struct varlena and various
derived structures to catch anyplace that was accessing them directly;
and clean up various places so caught.  In itself this patch doesn't
change any behavior at all, but it is necessary infrastructure if we hope
to play any games with the representation of varlena headers.
Greg Stark and Tom Lane
2007-02-27 23:48:10 +00:00

568 lines
12 KiB
C

/*
* stat functions
*/
#include "tsvector.h"
#include "ts_stat.h"
#include "funcapi.h"
#include "catalog/pg_type.h"
#include "executor/spi.h"
#include "common.h"
#include "ts_locale.h"
PG_FUNCTION_INFO_V1(tsstat_in);
Datum tsstat_in(PG_FUNCTION_ARGS);
Datum
tsstat_in(PG_FUNCTION_ARGS)
{
tsstat *stat = palloc(STATHDRSIZE);
SET_VARSIZE(stat, STATHDRSIZE);
stat->size = 0;
stat->weight = 0;
PG_RETURN_POINTER(stat);
}
PG_FUNCTION_INFO_V1(tsstat_out);
Datum tsstat_out(PG_FUNCTION_ARGS);
Datum
tsstat_out(PG_FUNCTION_ARGS)
{
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("tsstat_out not implemented")));
PG_RETURN_NULL();
}
static int
check_weight(tsvector * txt, WordEntry * wptr, int8 weight)
{
int len = POSDATALEN(txt, wptr);
int num = 0;
WordEntryPos *ptr = POSDATAPTR(txt, wptr);
while (len--)
{
if (weight & (1 << WEP_GETWEIGHT(*ptr)))
num++;
ptr++;
}
return num;
}
static WordEntry **
SEI_realloc(WordEntry ** in, uint32 *len)
{
if (*len == 0 || in == NULL)
{
*len = 8;
in = palloc(sizeof(WordEntry *) * (*len));
}
else
{
*len *= 2;
in = repalloc(in, sizeof(WordEntry *) * (*len));
}
return in;
}
static int
compareStatWord(StatEntry * a, WordEntry * b, tsstat * stat, tsvector * txt)
{
if (a->len == b->len)
return strncmp(
STATSTRPTR(stat) + a->pos,
STRPTR(txt) + b->pos,
a->len
);
return (a->len > b->len) ? 1 : -1;
}
static tsstat *
formstat(tsstat * stat, tsvector * txt, WordEntry ** entry, uint32 len)
{
tsstat *newstat;
uint32 totallen,
nentry;
uint32 slen = 0;
WordEntry **ptr = entry;
char *curptr;
StatEntry *sptr,
*nptr;
while (ptr - entry < len)
{
slen += (*ptr)->len;
ptr++;
}
nentry = stat->size + len;
slen += STATSTRSIZE(stat);
totallen = CALCSTATSIZE(nentry, slen);
newstat = palloc(totallen);
SET_VARSIZE(newstat, totallen);
newstat->weight = stat->weight;
newstat->size = nentry;
memcpy(STATSTRPTR(newstat), STATSTRPTR(stat), STATSTRSIZE(stat));
curptr = STATSTRPTR(newstat) + STATSTRSIZE(stat);
ptr = entry;
sptr = STATPTR(stat);
nptr = STATPTR(newstat);
if (len == 1)
{
StatEntry *StopLow = STATPTR(stat);
StatEntry *StopHigh = (StatEntry *) STATSTRPTR(stat);
while (StopLow < StopHigh)
{
sptr = StopLow + (StopHigh - StopLow) / 2;
if (compareStatWord(sptr, *ptr, stat, txt) < 0)
StopLow = sptr + 1;
else
StopHigh = sptr;
}
nptr = STATPTR(newstat) + (StopLow - STATPTR(stat));
memcpy(STATPTR(newstat), STATPTR(stat), sizeof(StatEntry) * (StopLow - STATPTR(stat)));
if ((*ptr)->haspos)
nptr->nentry = (stat->weight) ? check_weight(txt, *ptr, stat->weight) : POSDATALEN(txt, *ptr);
else
nptr->nentry = 1;
nptr->ndoc = 1;
nptr->len = (*ptr)->len;
memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
nptr->pos = curptr - STATSTRPTR(newstat);
memcpy(nptr + 1, StopLow, sizeof(StatEntry) * (((StatEntry *) STATSTRPTR(stat)) - StopLow));
}
else
{
while (sptr - STATPTR(stat) < stat->size && ptr - entry < len)
{
if (compareStatWord(sptr, *ptr, stat, txt) < 0)
{
memcpy(nptr, sptr, sizeof(StatEntry));
sptr++;
}
else
{
if ((*ptr)->haspos)
nptr->nentry = (stat->weight) ? check_weight(txt, *ptr, stat->weight) : POSDATALEN(txt, *ptr);
else
nptr->nentry = 1;
nptr->ndoc = 1;
nptr->len = (*ptr)->len;
memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
nptr->pos = curptr - STATSTRPTR(newstat);
curptr += nptr->len;
ptr++;
}
nptr++;
}
memcpy(nptr, sptr, sizeof(StatEntry) * (stat->size - (sptr - STATPTR(stat))));
while (ptr - entry < len)
{
if ((*ptr)->haspos)
nptr->nentry = (stat->weight) ? check_weight(txt, *ptr, stat->weight) : POSDATALEN(txt, *ptr);
else
nptr->nentry = 1;
nptr->ndoc = 1;
nptr->len = (*ptr)->len;
memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
nptr->pos = curptr - STATSTRPTR(newstat);
curptr += nptr->len;
ptr++;
nptr++;
}
}
return newstat;
}
PG_FUNCTION_INFO_V1(ts_accum);
Datum ts_accum(PG_FUNCTION_ARGS);
Datum
ts_accum(PG_FUNCTION_ARGS)
{
tsstat *newstat,
*stat = (tsstat *) PG_GETARG_POINTER(0);
tsvector *txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));
WordEntry **newentry = NULL;
uint32 len = 0,
cur = 0;
StatEntry *sptr;
WordEntry *wptr;
int n = 0;
if (stat == NULL || PG_ARGISNULL(0))
{ /* Init in first */
stat = palloc(STATHDRSIZE);
SET_VARSIZE(stat, STATHDRSIZE);
stat->size = 0;
stat->weight = 0;
}
/* simple check of correctness */
if (txt == NULL || PG_ARGISNULL(1) || txt->size == 0)
{
PG_FREE_IF_COPY(txt, 1);
PG_RETURN_POINTER(stat);
}
sptr = STATPTR(stat);
wptr = ARRPTR(txt);
if (stat->size < 100 * txt->size)
{ /* merge */
while (sptr - STATPTR(stat) < stat->size && wptr - ARRPTR(txt) < txt->size)
{
int cmp = compareStatWord(sptr, wptr, stat, txt);
if (cmp < 0)
sptr++;
else if (cmp == 0)
{
if (stat->weight == 0)
{
sptr->ndoc++;
sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
}
else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
{
sptr->ndoc++;
sptr->nentry += n;
}
sptr++;
wptr++;
}
else
{
if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
{
if (cur == len)
newentry = SEI_realloc(newentry, &len);
newentry[cur] = wptr;
cur++;
}
wptr++;
}
}
while (wptr - ARRPTR(txt) < txt->size)
{
if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
{
if (cur == len)
newentry = SEI_realloc(newentry, &len);
newentry[cur] = wptr;
cur++;
}
wptr++;
}
}
else
{ /* search */
while (wptr - ARRPTR(txt) < txt->size)
{
StatEntry *StopLow = STATPTR(stat);
StatEntry *StopHigh = (StatEntry *) STATSTRPTR(stat);
int cmp;
while (StopLow < StopHigh)
{
sptr = StopLow + (StopHigh - StopLow) / 2;
cmp = compareStatWord(sptr, wptr, stat, txt);
if (cmp == 0)
{
if (stat->weight == 0)
{
sptr->ndoc++;
sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
}
else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
{
sptr->ndoc++;
sptr->nentry += n;
}
break;
}
else if (cmp < 0)
StopLow = sptr + 1;
else
StopHigh = sptr;
}
if (StopLow >= StopHigh)
{ /* not found */
if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
{
if (cur == len)
newentry = SEI_realloc(newentry, &len);
newentry[cur] = wptr;
cur++;
}
}
wptr++;
}
}
if (cur == 0)
{ /* no new words */
PG_FREE_IF_COPY(txt, 1);
PG_RETURN_POINTER(stat);
}
newstat = formstat(stat, txt, newentry, cur);
pfree(newentry);
PG_FREE_IF_COPY(txt, 1);
/* pfree(stat); */
PG_RETURN_POINTER(newstat);
}
typedef struct
{
uint32 cur;
tsvector *stat;
} StatStorage;
static void
ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
tsstat * stat)
{
TupleDesc tupdesc;
MemoryContext oldcontext;
StatStorage *st;
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
st = palloc(sizeof(StatStorage));
st->cur = 0;
st->stat = palloc(VARSIZE(stat));
memcpy(st->stat, stat, VARSIZE(stat));
funcctx->user_fctx = (void *) st;
if (get_call_result_type(fcinfo, NULL, &tupdesc) != TYPEFUNC_COMPOSITE)
elog(ERROR, "return type must be a row type");
tupdesc = CreateTupleDescCopy(tupdesc);
funcctx->attinmeta = TupleDescGetAttInMetadata(tupdesc);
MemoryContextSwitchTo(oldcontext);
}
static Datum
ts_process_call(FuncCallContext *funcctx)
{
StatStorage *st;
st = (StatStorage *) funcctx->user_fctx;
if (st->cur < st->stat->size)
{
Datum result;
char *values[3];
char ndoc[16];
char nentry[16];
StatEntry *entry = STATPTR(st->stat) + st->cur;
HeapTuple tuple;
values[1] = ndoc;
sprintf(ndoc, "%d", entry->ndoc);
values[2] = nentry;
sprintf(nentry, "%d", entry->nentry);
values[0] = palloc(entry->len + 1);
memcpy(values[0], STATSTRPTR(st->stat) + entry->pos, entry->len);
(values[0])[entry->len] = '\0';
tuple = BuildTupleFromCStrings(funcctx->attinmeta, values);
result = HeapTupleGetDatum(tuple);
pfree(values[0]);
st->cur++;
return result;
}
else
{
pfree(st->stat);
pfree(st);
}
return (Datum) 0;
}
PG_FUNCTION_INFO_V1(ts_accum_finish);
Datum ts_accum_finish(PG_FUNCTION_ARGS);
Datum
ts_accum_finish(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
Datum result;
if (SRF_IS_FIRSTCALL())
{
funcctx = SRF_FIRSTCALL_INIT();
ts_setup_firstcall(fcinfo, funcctx, (tsstat *) PG_GETARG_POINTER(0));
}
funcctx = SRF_PERCALL_SETUP();
if ((result = ts_process_call(funcctx)) != (Datum) 0)
SRF_RETURN_NEXT(funcctx, result);
SRF_RETURN_DONE(funcctx);
}
static Oid tiOid = InvalidOid;
static void
get_ti_Oid(void)
{
int ret;
bool isnull;
if ((ret = SPI_exec("select oid from pg_type where typname='tsvector'", 1)) < 0)
/* internal error */
elog(ERROR, "SPI_exec to get tsvector oid returns %d", ret);
if (SPI_processed < 1)
/* internal error */
elog(ERROR, "there is no tsvector type");
tiOid = DatumGetObjectId(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
if (tiOid == InvalidOid)
/* internal error */
elog(ERROR, "tsvector type has InvalidOid");
}
static tsstat *
ts_stat_sql(text *txt, text *ws)
{
char *query = text2char(txt);
int i;
tsstat *newstat,
*stat;
bool isnull;
Portal portal;
void *plan;
if (tiOid == InvalidOid)
get_ti_Oid();
if ((plan = SPI_prepare(query, 0, NULL)) == NULL)
/* internal error */
elog(ERROR, "SPI_prepare('%s') returns NULL", query);
if ((portal = SPI_cursor_open(NULL, plan, NULL, NULL, false)) == NULL)
/* internal error */
elog(ERROR, "SPI_cursor_open('%s') returns NULL", query);
SPI_cursor_fetch(portal, true, 100);
if (SPI_tuptable->tupdesc->natts != 1)
/* internal error */
elog(ERROR, "number of fields doesn't equal to 1");
if (SPI_gettypeid(SPI_tuptable->tupdesc, 1) != tiOid)
/* internal error */
elog(ERROR, "column isn't of tsvector type");
stat = palloc(STATHDRSIZE);
SET_VARSIZE(stat, STATHDRSIZE);
stat->size = 0;
stat->weight = 0;
if (ws)
{
char *buf;
buf = VARDATA(ws);
while (buf - VARDATA(ws) < VARSIZE(ws) - VARHDRSZ)
{
if (pg_mblen(buf) == 1)
{
switch (*buf)
{
case 'A':
case 'a':
stat->weight |= 1 << 3;
break;
case 'B':
case 'b':
stat->weight |= 1 << 2;
break;
case 'C':
case 'c':
stat->weight |= 1 << 1;
break;
case 'D':
case 'd':
stat->weight |= 1;
break;
default:
stat->weight |= 0;
}
}
buf += pg_mblen(buf);
}
}
while (SPI_processed > 0)
{
for (i = 0; i < SPI_processed; i++)
{
Datum data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);
if (!isnull)
{
newstat = (tsstat *) DatumGetPointer(DirectFunctionCall2(
ts_accum,
PointerGetDatum(stat),
data
));
if (stat != newstat && stat)
pfree(stat);
stat = newstat;
}
}
SPI_freetuptable(SPI_tuptable);
SPI_cursor_fetch(portal, true, 100);
}
SPI_freetuptable(SPI_tuptable);
SPI_cursor_close(portal);
SPI_freeplan(plan);
pfree(query);
return stat;
}
PG_FUNCTION_INFO_V1(ts_stat);
Datum ts_stat(PG_FUNCTION_ARGS);
Datum
ts_stat(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
Datum result;
if (SRF_IS_FIRSTCALL())
{
tsstat *stat;
text *txt = PG_GETARG_TEXT_P(0);
text *ws = (PG_NARGS() > 1) ? PG_GETARG_TEXT_P(1) : NULL;
funcctx = SRF_FIRSTCALL_INIT();
SPI_connect();
stat = ts_stat_sql(txt, ws);
PG_FREE_IF_COPY(txt, 0);
if (PG_NARGS() > 1)
PG_FREE_IF_COPY(ws, 1);
ts_setup_firstcall(fcinfo, funcctx, stat);
SPI_finish();
}
funcctx = SRF_PERCALL_SETUP();
if ((result = ts_process_call(funcctx)) != (Datum) 0)
SRF_RETURN_NEXT(funcctx, result);
SRF_RETURN_DONE(funcctx);
}