diff options
Diffstat (limited to 'src/backend/parser/parse_cte.c')
-rw-r--r-- | src/backend/parser/parse_cte.c | 60 |
1 files changed, 48 insertions, 12 deletions
diff --git a/src/backend/parser/parse_cte.c b/src/backend/parser/parse_cte.c index 4d3d33eb079..23b72b245b2 100644 --- a/src/backend/parser/parse_cte.c +++ b/src/backend/parser/parse_cte.c @@ -115,7 +115,7 @@ transformWithClause(ParseState *pstate, WithClause *withClause) * list. Check this right away so we needn't worry later. * * Also, tentatively mark each CTE as non-recursive, and initialize its - * reference count to zero. + * reference count to zero, and set pstate->p_hasModifyingCTE if needed. */ foreach(lc, withClause->ctes) { @@ -136,6 +136,16 @@ transformWithClause(ParseState *pstate, WithClause *withClause) cte->cterecursive = false; cte->cterefcount = 0; + + if (!IsA(cte->ctequery, SelectStmt)) + { + /* must be a data-modifying statement */ + Assert(IsA(cte->ctequery, InsertStmt) || + IsA(cte->ctequery, UpdateStmt) || + IsA(cte->ctequery, DeleteStmt)); + + pstate->p_hasModifyingCTE = true; + } } if (withClause->recursive) @@ -229,20 +239,20 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte) Query *query; /* Analysis not done already */ - Assert(IsA(cte->ctequery, SelectStmt)); + Assert(!IsA(cte->ctequery, Query)); query = parse_sub_analyze(cte->ctequery, pstate, cte, false); cte->ctequery = (Node *) query; /* - * Check that we got something reasonable. Many of these conditions are - * impossible given restrictions of the grammar, but check 'em anyway. - * (These are the same checks as in transformRangeSubselect.) + * Check that we got something reasonable. These first two cases should + * be prevented by the grammar. */ - if (!IsA(query, Query) || - query->commandType != CMD_SELECT || - query->utilityStmt != NULL) - elog(ERROR, "unexpected non-SELECT command in subquery in WITH"); + if (!IsA(query, Query)) + elog(ERROR, "unexpected non-Query statement in WITH"); + if (query->utilityStmt != NULL) + elog(ERROR, "unexpected utility statement in WITH"); + if (query->intoClause) ereport(ERROR, (errcode(ERRCODE_SYNTAX_ERROR), @@ -250,10 +260,28 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte) parser_errposition(pstate, exprLocation((Node *) query->intoClause)))); + /* + * We disallow data-modifying WITH except at the top level of a query, + * because it's not clear when such a modification should be executed. + */ + if (query->commandType != CMD_SELECT && + pstate->parentParseState != NULL) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("WITH clause containing a data-modifying statement must be at the top level"), + parser_errposition(pstate, cte->location))); + + /* + * CTE queries are always marked not canSetTag. (Currently this only + * matters for data-modifying statements, for which the flag will be + * propagated to the ModifyTable plan node.) + */ + query->canSetTag = false; + if (!cte->cterecursive) { /* Compute the output column names/types if not done yet */ - analyzeCTETargetList(pstate, cte, query->targetList); + analyzeCTETargetList(pstate, cte, GetCTETargetList(cte)); } else { @@ -273,7 +301,7 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte) lctypmod = list_head(cte->ctecoltypmods); lccoll = list_head(cte->ctecolcollations); varattno = 0; - foreach(lctlist, query->targetList) + foreach(lctlist, GetCTETargetList(cte)) { TargetEntry *te = (TargetEntry *) lfirst(lctlist); Node *texpr; @@ -613,12 +641,20 @@ checkWellFormedRecursion(CteState *cstate) CommonTableExpr *cte = cstate->items[i].cte; SelectStmt *stmt = (SelectStmt *) cte->ctequery; - Assert(IsA(stmt, SelectStmt)); /* not analyzed yet */ + Assert(!IsA(stmt, Query)); /* not analyzed yet */ /* Ignore items that weren't found to be recursive */ if (!cte->cterecursive) continue; + /* Must be a SELECT statement */ + if (!IsA(stmt, SelectStmt)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_RECURSION), + errmsg("recursive query \"%s\" must not contain data-modifying statements", + cte->ctename), + parser_errposition(cstate->pstate, cte->location))); + /* Must have top-level UNION */ if (stmt->op != SETOP_UNION) ereport(ERROR, |