diff --git a/src/bin/psql/common.c b/src/bin/psql/common.c index 2cb2e9bb3b..2b67a439da 100644 --- a/src/bin/psql/common.c +++ b/src/bin/psql/common.c @@ -107,6 +107,65 @@ setQFout(const char *fname) } +/* + * Variable-fetching callback for flex lexer + * + * If the specified variable exists, return its value as a string (malloc'd + * and expected to be freed by the caller); else return NULL. + * + * If "escape" is true, return the value suitably quoted and escaped, + * as an identifier or string literal depending on "as_ident". + * (Failure in escaping should lead to returning NULL.) + */ +char * +psql_get_variable(const char *varname, bool escape, bool as_ident) +{ + char *result; + const char *value; + + value = GetVariable(pset.vars, varname); + if (!value) + return NULL; + + if (escape) + { + char *escaped_value; + + if (!pset.db) + { + psql_error("can't escape without active connection\n"); + return NULL; + } + + if (as_ident) + escaped_value = + PQescapeIdentifier(pset.db, value, strlen(value)); + else + escaped_value = + PQescapeLiteral(pset.db, value, strlen(value)); + + if (escaped_value == NULL) + { + const char *error = PQerrorMessage(pset.db); + + psql_error("%s", error); + return NULL; + } + + /* + * Rather than complicate the lexer's API with a notion of which + * free() routine to use, just pay the price of an extra strdup(). + */ + result = pg_strdup(escaped_value); + PQfreemem(escaped_value); + } + else + result = pg_strdup(value); + + return result; +} + + /* * Error reporting for scripts. Errors should look like * psql:filename:lineno: message diff --git a/src/bin/psql/common.h b/src/bin/psql/common.h index ce7b93f9e5..ba4c5699b3 100644 --- a/src/bin/psql/common.h +++ b/src/bin/psql/common.h @@ -18,6 +18,8 @@ extern bool openQueryOutputFile(const char *fname, FILE **fout, bool *is_pipe); extern bool setQFout(const char *fname); +extern char *psql_get_variable(const char *varname, bool escape, bool as_ident); + extern void psql_error(const char *fmt,...) pg_attribute_printf(1, 2); extern void NoticeProcessor(void *arg, const char *message); diff --git a/src/bin/psql/mainloop.c b/src/bin/psql/mainloop.c index dadbd29397..bade35139b 100644 --- a/src/bin/psql/mainloop.c +++ b/src/bin/psql/mainloop.c @@ -8,7 +8,6 @@ #include "postgres_fe.h" #include "mainloop.h" - #include "command.h" #include "common.h" #include "input.h" @@ -17,6 +16,13 @@ #include "mb/pg_wchar.h" +/* callback functions for our flex lexer */ +const PsqlScanCallbacks psqlscan_callbacks = { + psql_get_variable, + psql_error +}; + + /* * Main processing loop for reading lines of input * and sending them to the backend. @@ -61,7 +67,7 @@ MainLoop(FILE *source) pset.stmt_lineno = 1; /* Create working state */ - scan_state = psql_scan_create(); + scan_state = psql_scan_create(&psqlscan_callbacks); query_buf = createPQExpBuffer(); previous_buf = createPQExpBuffer(); @@ -233,7 +239,8 @@ MainLoop(FILE *source) /* * Parse line, looking for command separators. */ - psql_scan_setup(scan_state, line, strlen(line)); + psql_scan_setup(scan_state, line, strlen(line), + pset.encoding, standard_strings()); success = true; line_saved_in_history = false; @@ -373,7 +380,8 @@ MainLoop(FILE *source) resetPQExpBuffer(query_buf); /* reset parsing state since we are rescanning whole line */ psql_scan_reset(scan_state); - psql_scan_setup(scan_state, line, strlen(line)); + psql_scan_setup(scan_state, line, strlen(line), + pset.encoding, standard_strings()); line_saved_in_history = false; prompt_status = PROMPT_READY; } diff --git a/src/bin/psql/mainloop.h b/src/bin/psql/mainloop.h index e6476ca7c6..5ee8dc7f63 100644 --- a/src/bin/psql/mainloop.h +++ b/src/bin/psql/mainloop.h @@ -8,6 +8,10 @@ #ifndef MAINLOOP_H #define MAINLOOP_H +#include "psqlscan.h" + +extern const PsqlScanCallbacks psqlscan_callbacks; + extern int MainLoop(FILE *source); #endif /* MAINLOOP_H */ diff --git a/src/bin/psql/psqlscan.h b/src/bin/psql/psqlscan.h index 674ba69eda..82c66dcdf9 100644 --- a/src/bin/psql/psqlscan.h +++ b/src/bin/psql/psqlscan.h @@ -36,12 +36,23 @@ enum slash_option_type OT_NO_EVAL /* no expansion of backticks or variables */ }; +/* Callback functions to be used by the lexer */ +typedef struct PsqlScanCallbacks +{ + /* Fetch value of a variable, as a pfree'able string; NULL if unknown */ + /* This pointer can be NULL if no variable substitution is wanted */ + char *(*get_variable) (const char *varname, bool escape, bool as_ident); + /* Print an error message someplace appropriate */ + void (*write_error) (const char *fmt,...) pg_attribute_printf(1, 2); +} PsqlScanCallbacks; -extern PsqlScanState psql_scan_create(void); + +extern PsqlScanState psql_scan_create(const PsqlScanCallbacks *callbacks); extern void psql_scan_destroy(PsqlScanState state); extern void psql_scan_setup(PsqlScanState state, - const char *line, int line_len); + const char *line, int line_len, + int encoding, bool std_strings); extern void psql_scan_finish(PsqlScanState state); extern PsqlScanResult psql_scan(PsqlScanState state, diff --git a/src/bin/psql/psqlscan.l b/src/bin/psql/psqlscan.l index bbe0172737..b741ab8fc5 100644 --- a/src/bin/psql/psqlscan.l +++ b/src/bin/psql/psqlscan.l @@ -2,7 +2,7 @@ /*------------------------------------------------------------------------- * * psqlscan.l - * lexical scanner for psql + * lexical scanner for psql (and other frontend programs) * * This code is mainly needed to determine where the end of a SQL statement * is: we are looking for semicolons that are not within quotes, comments, @@ -41,11 +41,7 @@ #include "psqlscan.h" -#include - -#include "common.h" -#include "settings.h" -#include "variables.h" +#include "libpq-fe.h" /* @@ -83,6 +79,7 @@ typedef struct PsqlScanStateData /* safe_encoding, curline, refline are used by emit() to replace FFs */ int encoding; /* encoding being used now */ bool safe_encoding; /* is current encoding "safe"? */ + bool std_strings; /* are string literals standard? */ const char *curline; /* actual flex input string for cur buf */ const char *refline; /* original data for cur buffer */ @@ -94,6 +91,11 @@ typedef struct PsqlScanStateData int paren_depth; /* depth of nesting in parentheses */ int xcdepth; /* depth of nesting in slash-star comments */ char *dolqstart; /* current $foo$ quote start string */ + + /* + * Callback functions provided by the program making use of the lexer. + */ + const PsqlScanCallbacks *callbacks; } PsqlScanStateData; static PsqlScanState cur_state; /* current state while active */ @@ -135,6 +137,7 @@ static void escape_variable(bool as_ident); %option nounput %option noyywrap %option warn +%option prefix="psql_yy" /* * All of the following definitions and rules should exactly match @@ -508,7 +511,7 @@ other . } {xqstart} { - if (standard_strings()) + if (cur_state->std_strings) BEGIN(xq); else BEGIN(xe); @@ -737,10 +740,15 @@ other . :{variable_char}+ { /* Possible psql variable substitution */ char *varname; - const char *value; + char *value; varname = extract_substring(yytext + 1, yyleng - 1); - value = GetVariable(pset.vars, varname); + if (cur_state->callbacks->get_variable) + value = cur_state->callbacks->get_variable(varname, + false, + false); + else + value = NULL; if (value) { @@ -748,8 +756,8 @@ other . if (var_is_current_source(cur_state, varname)) { /* Recursive expansion --- don't go there */ - psql_error("skipping recursive expansion of variable \"%s\"\n", - varname); + cur_state->callbacks->write_error("skipping recursive expansion of variable \"%s\"\n", + varname); /* Instead copy the string as is */ ECHO; } @@ -759,6 +767,7 @@ other . push_new_buffer(value, varname); /* yy_scan_string already made buffer active */ } + free(value); } else { @@ -1026,15 +1035,18 @@ other . :{variable_char}+ { /* Possible psql variable substitution */ - if (option_type == OT_NO_EVAL) + if (option_type == OT_NO_EVAL || + cur_state->callbacks->get_variable == NULL) ECHO; else { char *varname; - const char *value; + char *value; varname = extract_substring(yytext + 1, yyleng - 1); - value = GetVariable(pset.vars, varname); + value = cur_state->callbacks->get_variable(varname, + false, + false); free(varname); /* @@ -1045,7 +1057,10 @@ other . * Note that we needn't guard against recursion here. */ if (value) + { appendPQExpBufferStr(output_buf, value); + free(value); + } else ECHO; @@ -1191,14 +1206,20 @@ other . /* * Create a lexer working state struct. + * + * callbacks is a struct of function pointers that encapsulate some + * behavior we need from the surrounding program. This struct must + * remain valid for the lifespan of the PsqlScanState. */ PsqlScanState -psql_scan_create(void) +psql_scan_create(const PsqlScanCallbacks *callbacks) { PsqlScanState state; state = (PsqlScanStateData *) pg_malloc0(sizeof(PsqlScanStateData)); + state->callbacks = callbacks; + psql_scan_reset(state); return state; @@ -1225,18 +1246,25 @@ psql_scan_destroy(PsqlScanState state) * be called when scanning is complete. Note that the lexer retains * a pointer to the storage at *line --- this string must not be altered * or freed until after psql_scan_finish is called. + * + * encoding is the libpq identifier for the character encoding in use, + * and std_strings says whether standard_conforming_strings is on. */ void psql_scan_setup(PsqlScanState state, - const char *line, int line_len) + const char *line, int line_len, + int encoding, bool std_strings) { /* Mustn't be scanning already */ Assert(state->scanbufhandle == NULL); Assert(state->buffer_stack == NULL); /* Do we need to hack the character set encoding? */ - state->encoding = pset.encoding; - state->safe_encoding = pg_valid_server_encoding_id(state->encoding); + state->encoding = encoding; + state->safe_encoding = pg_valid_server_encoding_id(encoding); + + /* Save standard-strings flag as well */ + state->std_strings = std_strings; /* needed for prepare_buffer */ cur_state = state; @@ -1615,7 +1643,7 @@ psql_scan_slash_option(PsqlScanState state, { if (!inquotes && type == OT_SQLID) *cp = pg_tolower((unsigned char) *cp); - cp += PQmblen(cp, pset.encoding); + cp += PQmblen(cp, state->encoding); } } } @@ -1936,53 +1964,31 @@ extract_substring(const char *txt, int len) * If the variable name is found, escape its value using the appropriate * quoting method and emit the value to output_buf. (Since the result is * surely quoted, there is never any reason to rescan it.) If we don't - * find the variable or the escaping function fails, emit the token as-is. + * find the variable or escaping fails, emit the token as-is. */ static void escape_variable(bool as_ident) { char *varname; - const char *value; + char *value; /* Variable lookup. */ varname = extract_substring(yytext + 2, yyleng - 3); - value = GetVariable(pset.vars, varname); + if (cur_state->callbacks->get_variable) + value = cur_state->callbacks->get_variable(varname, true, as_ident); + else + value = NULL; free(varname); - /* Escaping. */ if (value) { - if (!pset.db) - psql_error("can't escape without active connection\n"); - else - { - char *escaped_value; - - if (as_ident) - escaped_value = - PQescapeIdentifier(pset.db, value, strlen(value)); - else - escaped_value = - PQescapeLiteral(pset.db, value, strlen(value)); - - if (escaped_value == NULL) - { - const char *error = PQerrorMessage(pset.db); - - psql_error("%s", error); - } - else - { - appendPQExpBufferStr(output_buf, escaped_value); - PQfreemem(escaped_value); - return; - } - } + /* Emit the suitably-escaped value */ + appendPQExpBufferStr(output_buf, value); + free(value); + } + else + { + /* Emit original token as-is */ + emit(yytext, yyleng); } - - /* - * If we reach this point, some kind of error has occurred. Emit the - * original text into the output buffer. - */ - emit(yytext, yyleng); } diff --git a/src/bin/psql/startup.c b/src/bin/psql/startup.c index 6916f6f461..4bb3fdc595 100644 --- a/src/bin/psql/startup.c +++ b/src/bin/psql/startup.c @@ -336,10 +336,10 @@ main(int argc, char *argv[]) if (pset.echo == PSQL_ECHO_ALL) puts(cell->val); - scan_state = psql_scan_create(); + scan_state = psql_scan_create(&psqlscan_callbacks); psql_scan_setup(scan_state, - cell->val, - strlen(cell->val)); + cell->val, strlen(cell->val), + pset.encoding, standard_strings()); successResult = HandleSlashCmds(scan_state, NULL) != PSQL_CMD_ERROR ? EXIT_SUCCESS : EXIT_FAILURE;