diff options
Diffstat (limited to 'src/backend/parser/parse_agg.c')
-rw-r--r-- | src/backend/parser/parse_agg.c | 214 |
1 files changed, 149 insertions, 65 deletions
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index bee7d8346a3..bd095d05c0b 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -26,6 +26,7 @@ #include "parser/parse_clause.h" #include "parser/parse_coerce.h" #include "parser/parse_expr.h" +#include "parser/parse_relation.h" #include "parser/parsetree.h" #include "rewrite/rewriteManip.h" #include "utils/builtins.h" @@ -47,11 +48,12 @@ typedef struct bool hasJoinRTEs; List *groupClauses; List *groupClauseCommonVars; + List *gset_common; bool have_non_var_grouping; List **func_grouped_rels; int sublevels_up; bool in_agg_direct_args; -} check_ungrouped_columns_context; +} substitute_grouped_columns_context; static int check_agg_arguments(ParseState *pstate, List *directargs, @@ -59,17 +61,20 @@ static int check_agg_arguments(ParseState *pstate, Expr *filter); static bool check_agg_arguments_walker(Node *node, check_agg_arguments_context *context); -static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry, - List *groupClauses, List *groupClauseCommonVars, - bool have_non_var_grouping, - List **func_grouped_rels); -static bool check_ungrouped_columns_walker(Node *node, - check_ungrouped_columns_context *context); +static Node *substitute_grouped_columns(Node *node, ParseState *pstate, Query *qry, + List *groupClauses, List *groupClauseCommonVars, + List *gset_common, + bool have_non_var_grouping, + List **func_grouped_rels); +static Node *substitute_grouped_columns_mutator(Node *node, + substitute_grouped_columns_context *context); static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry, List *groupClauses, bool hasJoinRTEs, bool have_non_var_grouping); static bool finalize_grouping_exprs_walker(Node *node, - check_ungrouped_columns_context *context); + substitute_grouped_columns_context *context); +static Var *buildGroupedVar(int attnum, Index ressortgroupref, + substitute_grouped_columns_context *context); static void check_agglevels_and_constraints(ParseState *pstate, Node *expr); static List *expand_groupingset_node(GroupingSet *gs); static Node *make_agg_arg(Oid argtype, Oid argcollation); @@ -1066,7 +1071,9 @@ transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc, /* * parseCheckAggregates - * Check for aggregates where they shouldn't be and improper grouping. + * Check for aggregates where they shouldn't be and improper grouping, and + * replace grouped variables in the targetlist and HAVING clause with Vars + * that reference the RTE_GROUP RTE. * This function should be called after the target list and qualifications * are finalized. * @@ -1156,7 +1163,7 @@ parseCheckAggregates(ParseState *pstate, Query *qry) /* * Build a list of the acceptable GROUP BY expressions for use by - * check_ungrouped_columns(). + * substitute_grouped_columns(). * * We get the TLE, not just the expr, because GROUPING wants to know the * sortgroupref. @@ -1209,7 +1216,24 @@ parseCheckAggregates(ParseState *pstate, Query *qry) } /* - * Check the targetlist and HAVING clause for ungrouped variables. + * If there are any acceptable GROUP BY expressions, build an RTE and + * nsitem for the result of the grouping step. + */ + if (groupClauses) + { + pstate->p_grouping_nsitem = + addRangeTableEntryForGroup(pstate, groupClauses); + + /* Set qry->rtable again in case it was previously NIL */ + qry->rtable = pstate->p_rtable; + /* Mark the Query as having RTE_GROUP RTE */ + qry->hasGroupRTE = true; + } + + /* + * Replace grouped variables in the targetlist and HAVING clause with Vars + * that reference the RTE_GROUP RTE. Emit an error message if we find any + * ungrouped variables. * * Note: because we check resjunk tlist elements as well as regular ones, * this will also find ungrouped variables that came from ORDER BY and @@ -1225,10 +1249,12 @@ parseCheckAggregates(ParseState *pstate, Query *qry) have_non_var_grouping); if (hasJoinRTEs) clause = flatten_join_alias_vars(NULL, qry, clause); - check_ungrouped_columns(clause, pstate, qry, - groupClauses, groupClauseCommonVars, - have_non_var_grouping, - &func_grouped_rels); + qry->targetList = (List *) + substitute_grouped_columns(clause, pstate, qry, + groupClauses, groupClauseCommonVars, + gset_common, + have_non_var_grouping, + &func_grouped_rels); clause = (Node *) qry->havingQual; finalize_grouping_exprs(clause, pstate, qry, @@ -1236,10 +1262,12 @@ parseCheckAggregates(ParseState *pstate, Query *qry) have_non_var_grouping); if (hasJoinRTEs) clause = flatten_join_alias_vars(NULL, qry, clause); - check_ungrouped_columns(clause, pstate, qry, - groupClauses, groupClauseCommonVars, - have_non_var_grouping, - &func_grouped_rels); + qry->havingQual = + substitute_grouped_columns(clause, pstate, qry, + groupClauses, groupClauseCommonVars, + gset_common, + have_non_var_grouping, + &func_grouped_rels); /* * Per spec, aggregates can't appear in a recursive term. @@ -1253,14 +1281,16 @@ parseCheckAggregates(ParseState *pstate, Query *qry) } /* - * check_ungrouped_columns - - * Scan the given expression tree for ungrouped variables (variables - * that are not listed in the groupClauses list and are not within - * the arguments of aggregate functions). Emit a suitable error message - * if any are found. + * substitute_grouped_columns - + * Scan the given expression tree for grouped variables (variables that + * are listed in the groupClauses list) and replace them with Vars that + * reference the RTE_GROUP RTE. Emit a suitable error message if any + * ungrouped variables (variables that are not listed in the groupClauses + * list and are not within the arguments of aggregate functions) are + * found. * * NOTE: we assume that the given clause has been transformed suitably for - * parser output. This means we can use expression_tree_walker. + * parser output. This means we can use expression_tree_mutator. * * NOTE: we recognize grouping expressions in the main query, but only * grouping Vars in subqueries. For example, this will be rejected, @@ -1273,37 +1303,39 @@ parseCheckAggregates(ParseState *pstate, Query *qry) * This appears to require a whole custom version of equal(), which is * way more pain than the feature seems worth. */ -static void -check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry, - List *groupClauses, List *groupClauseCommonVars, - bool have_non_var_grouping, - List **func_grouped_rels) +static Node * +substitute_grouped_columns(Node *node, ParseState *pstate, Query *qry, + List *groupClauses, List *groupClauseCommonVars, + List *gset_common, + bool have_non_var_grouping, + List **func_grouped_rels) { - check_ungrouped_columns_context context; + substitute_grouped_columns_context context; context.pstate = pstate; context.qry = qry; context.hasJoinRTEs = false; /* assume caller flattened join Vars */ context.groupClauses = groupClauses; context.groupClauseCommonVars = groupClauseCommonVars; + context.gset_common = gset_common; context.have_non_var_grouping = have_non_var_grouping; context.func_grouped_rels = func_grouped_rels; context.sublevels_up = 0; context.in_agg_direct_args = false; - check_ungrouped_columns_walker(node, &context); + return substitute_grouped_columns_mutator(node, &context); } -static bool -check_ungrouped_columns_walker(Node *node, - check_ungrouped_columns_context *context) +static Node * +substitute_grouped_columns_mutator(Node *node, + substitute_grouped_columns_context *context) { ListCell *gl; if (node == NULL) - return false; + return NULL; if (IsA(node, Const) || IsA(node, Param)) - return false; /* constants are always acceptable */ + return node; /* constants are always acceptable */ if (IsA(node, Aggref)) { @@ -1314,19 +1346,21 @@ check_ungrouped_columns_walker(Node *node, /* * If we find an aggregate call of the original level, do not * recurse into its normal arguments, ORDER BY arguments, or - * filter; ungrouped vars there are not an error. But we should - * check direct arguments as though they weren't in an aggregate. - * We set a special flag in the context to help produce a useful + * filter; grouped vars there do not need to be replaced and + * ungrouped vars there are not an error. But we should check + * direct arguments as though they weren't in an aggregate. We + * set a special flag in the context to help produce a useful * error message for ungrouped vars in direct arguments. */ - bool result; + agg = copyObject(agg); Assert(!context->in_agg_direct_args); context->in_agg_direct_args = true; - result = check_ungrouped_columns_walker((Node *) agg->aggdirectargs, - context); + agg->aggdirectargs = (List *) + substitute_grouped_columns_mutator((Node *) agg->aggdirectargs, + context); context->in_agg_direct_args = false; - return result; + return (Node *) agg; } /* @@ -1336,7 +1370,7 @@ check_ungrouped_columns_walker(Node *node, * levels, however. */ if ((int) agg->agglevelsup > context->sublevels_up) - return false; + return node; } if (IsA(node, GroupingFunc)) @@ -1346,7 +1380,7 @@ check_ungrouped_columns_walker(Node *node, /* handled GroupingFunc separately, no need to recheck at this level */ if ((int) grp->agglevelsup >= context->sublevels_up) - return false; + return node; } /* @@ -1358,12 +1392,20 @@ check_ungrouped_columns_walker(Node *node, */ if (context->have_non_var_grouping && context->sublevels_up == 0) { + int attnum = 0; + foreach(gl, context->groupClauses) { - TargetEntry *tle = lfirst(gl); + TargetEntry *tle = (TargetEntry *) lfirst(gl); + attnum++; if (equal(node, tle->expr)) - return false; /* acceptable, do not descend more */ + { + /* acceptable, replace it with a GROUP Var */ + return (Node *) buildGroupedVar(attnum, + tle->ressortgroupref, + context); + } } } @@ -1380,22 +1422,31 @@ check_ungrouped_columns_walker(Node *node, char *attname; if (var->varlevelsup != context->sublevels_up) - return false; /* it's not local to my query, ignore */ + return node; /* it's not local to my query, ignore */ /* * Check for a match, if we didn't do it above. */ if (!context->have_non_var_grouping || context->sublevels_up != 0) { + int attnum = 0; + foreach(gl, context->groupClauses) { - Var *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr; + TargetEntry *tle = (TargetEntry *) lfirst(gl); + Var *gvar = (Var *) tle->expr; + attnum++; if (IsA(gvar, Var) && gvar->varno == var->varno && gvar->varattno == var->varattno && gvar->varlevelsup == 0) - return false; /* acceptable, we're okay */ + { + /* acceptable, replace it with a GROUP Var */ + return (Node *) buildGroupedVar(attnum, + tle->ressortgroupref, + context); + } } } @@ -1416,7 +1467,7 @@ check_ungrouped_columns_walker(Node *node, * the constraintDeps list. */ if (list_member_int(*context->func_grouped_rels, var->varno)) - return false; /* previously proven acceptable */ + return node; /* previously proven acceptable */ Assert(var->varno > 0 && (int) var->varno <= list_length(context->pstate->p_rtable)); @@ -1431,7 +1482,7 @@ check_ungrouped_columns_walker(Node *node, { *context->func_grouped_rels = lappend_int(*context->func_grouped_rels, var->varno); - return false; /* acceptable */ + return node; /* acceptable */ } } @@ -1456,18 +1507,18 @@ check_ungrouped_columns_walker(Node *node, if (IsA(node, Query)) { /* Recurse into subselects */ - bool result; + Query *newnode; context->sublevels_up++; - result = query_tree_walker((Query *) node, - check_ungrouped_columns_walker, - (void *) context, - 0); + newnode = query_tree_mutator((Query *) node, + substitute_grouped_columns_mutator, + (void *) context, + 0); context->sublevels_up--; - return result; + return (Node *) newnode; } - return expression_tree_walker(node, check_ungrouped_columns_walker, - (void *) context); + return expression_tree_mutator(node, substitute_grouped_columns_mutator, + (void *) context); } /* @@ -1475,9 +1526,9 @@ check_ungrouped_columns_walker(Node *node, * Scan the given expression tree for GROUPING() and related calls, * and validate and process their arguments. * - * This is split out from check_ungrouped_columns above because it needs + * This is split out from substitute_grouped_columns above because it needs * to modify the nodes (which it does in-place, not via a mutator) while - * check_ungrouped_columns may see only a copy of the original thanks to + * substitute_grouped_columns may see only a copy of the original thanks to * flattening of join alias vars. So here, we flatten each individual * GROUPING argument as we see it before comparing it. */ @@ -1486,13 +1537,14 @@ finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry, List *groupClauses, bool hasJoinRTEs, bool have_non_var_grouping) { - check_ungrouped_columns_context context; + substitute_grouped_columns_context context; context.pstate = pstate; context.qry = qry; context.hasJoinRTEs = hasJoinRTEs; context.groupClauses = groupClauses; context.groupClauseCommonVars = NIL; + context.gset_common = NIL; context.have_non_var_grouping = have_non_var_grouping; context.func_grouped_rels = NULL; context.sublevels_up = 0; @@ -1502,7 +1554,7 @@ finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry, static bool finalize_grouping_exprs_walker(Node *node, - check_ungrouped_columns_context *context) + substitute_grouped_columns_context *context) { ListCell *gl; @@ -1643,6 +1695,38 @@ finalize_grouping_exprs_walker(Node *node, (void *) context); } +/* + * buildGroupedVar - + * build a Var node that references the RTE_GROUP RTE + */ +static Var * +buildGroupedVar(int attnum, Index ressortgroupref, + substitute_grouped_columns_context *context) +{ + Var *var; + ParseNamespaceItem *grouping_nsitem = context->pstate->p_grouping_nsitem; + ParseNamespaceColumn *nscol = grouping_nsitem->p_nscolumns + attnum - 1; + + Assert(nscol->p_varno == grouping_nsitem->p_rtindex); + Assert(nscol->p_varattno == attnum); + var = makeVar(nscol->p_varno, + nscol->p_varattno, + nscol->p_vartype, + nscol->p_vartypmod, + nscol->p_varcollid, + context->sublevels_up); + /* makeVar doesn't offer parameters for these, so set by hand: */ + var->varnosyn = nscol->p_varnosyn; + var->varattnosyn = nscol->p_varattnosyn; + + if (context->qry->groupingSets && + !list_member_int(context->gset_common, ressortgroupref)) + var->varnullingrels = + bms_add_member(var->varnullingrels, grouping_nsitem->p_rtindex); + + return var; +} + /* * Given a GroupingSet node, expand it and return a list of lists. |