Arrange for PreventTransactionChain to reject commands submitted as part

of a multi-statement simple-Query message.  This bug goes all the way
back, but unfortunately is not nearly so easy to fix in existing releases;
it is only the recent ProcessUtility API change that makes it fixable in
HEAD.  Per report from William Garrison.
This commit is contained in:
Tom Lane 2007-03-22 19:55:04 +00:00
parent 686956375a
commit 4f896dac17
2 changed files with 20 additions and 8 deletions

View File

@ -10,7 +10,7 @@
* *
* *
* IDENTIFICATION * IDENTIFICATION
* $PostgreSQL: pgsql/src/backend/access/transam/xact.c,v 1.237 2007/03/13 14:32:25 petere Exp $ * $PostgreSQL: pgsql/src/backend/access/transam/xact.c,v 1.238 2007/03/22 19:55:04 tgl Exp $
* *
*------------------------------------------------------------------------- *-------------------------------------------------------------------------
*/ */
@ -2503,9 +2503,10 @@ AbortCurrentTransaction(void)
* completes). Subtransactions are verboten too. * completes). Subtransactions are verboten too.
* *
* isTopLevel: passed down from ProcessUtility to determine whether we are * isTopLevel: passed down from ProcessUtility to determine whether we are
* inside a function. (We will always fail if this is false, but it's * inside a function or multi-query querystring. (We will always fail if
* convenient to centralize the check here instead of making callers do it.) * this is false, but it's convenient to centralize the check here instead of
* stmtType: statement type name, for error messages. * making callers do it.)
* stmtType: statement type name, for error messages.
*/ */
void void
PreventTransactionChain(bool isTopLevel, const char *stmtType) PreventTransactionChain(bool isTopLevel, const char *stmtType)
@ -2537,7 +2538,8 @@ PreventTransactionChain(bool isTopLevel, const char *stmtType)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_ACTIVE_SQL_TRANSACTION), (errcode(ERRCODE_ACTIVE_SQL_TRANSACTION),
/* translator: %s represents an SQL statement name */ /* translator: %s represents an SQL statement name */
errmsg("%s cannot be executed from a function", stmtType))); errmsg("%s cannot be executed from a function or multi-command string",
stmtType)));
/* If we got past IsTransactionBlock test, should be in default state */ /* If we got past IsTransactionBlock test, should be in default state */
if (CurrentTransactionState->blockState != TBLOCK_DEFAULT && if (CurrentTransactionState->blockState != TBLOCK_DEFAULT &&

View File

@ -8,7 +8,7 @@
* *
* *
* IDENTIFICATION * IDENTIFICATION
* $PostgreSQL: pgsql/src/backend/tcop/postgres.c,v 1.528 2007/03/13 00:33:42 tgl Exp $ * $PostgreSQL: pgsql/src/backend/tcop/postgres.c,v 1.529 2007/03/22 19:55:04 tgl Exp $
* *
* NOTES * NOTES
* this is the "main" module of the postgres backend and * this is the "main" module of the postgres backend and
@ -765,6 +765,7 @@ exec_simple_query(const char *query_string)
ListCell *parsetree_item; ListCell *parsetree_item;
bool save_log_statement_stats = log_statement_stats; bool save_log_statement_stats = log_statement_stats;
bool was_logged = false; bool was_logged = false;
bool isTopLevel;
char msec_str[32]; char msec_str[32];
/* /*
@ -824,6 +825,15 @@ exec_simple_query(const char *query_string)
*/ */
MemoryContextSwitchTo(oldcontext); MemoryContextSwitchTo(oldcontext);
/*
* We'll tell PortalRun it's a top-level command iff there's exactly
* one raw parsetree. If more than one, it's effectively a transaction
* block and we want PreventTransactionChain to reject unsafe commands.
* (Note: we're assuming that query rewrite cannot add commands that are
* significant to PreventTransactionChain.)
*/
isTopLevel = (list_length(parsetree_list) == 1);
/* /*
* Run through the raw parsetree(s) and process each one. * Run through the raw parsetree(s) and process each one.
*/ */
@ -944,7 +954,7 @@ exec_simple_query(const char *query_string)
*/ */
(void) PortalRun(portal, (void) PortalRun(portal,
FETCH_ALL, FETCH_ALL,
true, /* top level */ isTopLevel,
receiver, receiver,
receiver, receiver,
completionTag); completionTag);
@ -1810,7 +1820,7 @@ exec_execute_message(const char *portal_name, long max_rows)
completed = PortalRun(portal, completed = PortalRun(portal,
max_rows, max_rows,
true, /* top level */ true, /* always top level */
receiver, receiver,
receiver, receiver,
completionTag); completionTag);