diff options
Diffstat (limited to 'src/backend/parser/parse_agg.c')
-rw-r--r-- | src/backend/parser/parse_agg.c | 371 |
1 files changed, 371 insertions, 0 deletions
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c new file mode 100644 index 00000000000..b64b92079ef --- /dev/null +++ b/src/backend/parser/parse_agg.c @@ -0,0 +1,371 @@ +/*------------------------------------------------------------------------- + * + * parse_agg.c-- + * handle aggregates in parser + * + * Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * $Header: /cvsroot/pgsql/src/backend/parser/parse_agg.c,v 1.1 1997/11/25 22:05:34 momjian Exp $ + * + *------------------------------------------------------------------------- + */ +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "postgres.h" +#include "access/heapam.h" +#include "catalog/pg_aggregate.h" +#include "nodes/nodeFuncs.h" +#include "nodes/primnodes.h" +#include "nodes/relation.h" +#include "optimizer/clauses.h" +#include "parser/parse_agg.h" +#include "parser/parse_node.h" +#include "parser/parse_target.h" +#include "utils/syscache.h" + +#ifdef 0 +#include "nodes/nodes.h" +#include "nodes/params.h" +#include "parse.h" /* for AND, OR, etc. */ +#include "catalog/pg_type.h" /* for INT4OID, etc. */ +#include "catalog/pg_proc.h" +#include "utils/elog.h" +#include "utils/builtins.h" /* namecmp(), textout() */ +#include "utils/lsyscache.h" +#include "utils/palloc.h" +#include "utils/mcxt.h" +#include "utils/acl.h" +#include "nodes/makefuncs.h" /* for makeResdom(), etc. */ +#include "commands/sequence.h" + +#endif + +/* + * AddAggToParseState - + * add the aggregate to the list of unique aggregates in pstate. + * + * SIDE EFFECT: aggno in target list entry will be modified + */ +void +AddAggToParseState(ParseState *pstate, Aggreg *aggreg) +{ + List *ag; + int i; + + /* + * see if we have the aggregate already (we only need to record the + * aggregate once) + */ + i = 0; + foreach(ag, pstate->p_aggs) + { + Aggreg *a = lfirst(ag); + + if (!strcmp(a->aggname, aggreg->aggname) && + equal(a->target, aggreg->target)) + { + + /* fill in the aggno and we're done */ + aggreg->aggno = i; + return; + } + i++; + } + + /* not found, new aggregate */ + aggreg->aggno = i; + pstate->p_numAgg++; + pstate->p_aggs = lappend(pstate->p_aggs, aggreg); + return; +} + +/* + * finalizeAggregates - + * fill in qry_aggs from pstate. Also checks to make sure that aggregates + * are used in the proper place. + */ +void +finalizeAggregates(ParseState *pstate, Query *qry) +{ + List *l; + int i; + + parseCheckAggregates(pstate, qry); + + qry->qry_numAgg = pstate->p_numAgg; + qry->qry_aggs = + (Aggreg **) palloc(sizeof(Aggreg *) * qry->qry_numAgg); + i = 0; + foreach(l, pstate->p_aggs) + qry->qry_aggs[i++] = (Aggreg *) lfirst(l); +} + +/* + * contain_agg_clause-- + * Recursively find aggreg nodes from a clause. + * + * Returns true if any aggregate found. + */ +bool +contain_agg_clause(Node *clause) +{ + if (clause == NULL) + return FALSE; + else if (IsA(clause, Aggreg)) + return TRUE; + else if (IsA(clause, Iter)) + return contain_agg_clause(((Iter *) clause)->iterexpr); + else if (single_node(clause)) + return FALSE; + else if (or_clause(clause)) + { + List *temp; + + foreach(temp, ((Expr *) clause)->args) + if (contain_agg_clause(lfirst(temp))) + return TRUE; + return FALSE; + } + else if (is_funcclause(clause)) + { + List *temp; + + foreach(temp, ((Expr *) clause)->args) + if (contain_agg_clause(lfirst(temp))) + return TRUE; + return FALSE; + } + else if (IsA(clause, ArrayRef)) + { + List *temp; + + foreach(temp, ((ArrayRef *) clause)->refupperindexpr) + if (contain_agg_clause(lfirst(temp))) + return TRUE; + foreach(temp, ((ArrayRef *) clause)->reflowerindexpr) + if (contain_agg_clause(lfirst(temp))) + return TRUE; + if (contain_agg_clause(((ArrayRef *) clause)->refexpr)) + return TRUE; + if (contain_agg_clause(((ArrayRef *) clause)->refassgnexpr)) + return TRUE; + return FALSE; + } + else if (not_clause(clause)) + return contain_agg_clause((Node *) get_notclausearg((Expr *) clause)); + else if (is_opclause(clause)) + return (contain_agg_clause((Node *) get_leftop((Expr *) clause)) || + contain_agg_clause((Node *) get_rightop((Expr *) clause))); + + return FALSE; +} + +/* + * exprIsAggOrGroupCol - + * returns true if the expression does not contain non-group columns. + */ +bool +exprIsAggOrGroupCol(Node *expr, List *groupClause) +{ + List *gl; + + if (expr == NULL || IsA(expr, Const) || + IsA(expr, Param) ||IsA(expr, Aggreg)) + return TRUE; + + foreach(gl, groupClause) + { + GroupClause *grpcl = lfirst(gl); + + if (equal(expr, grpcl->entry->expr)) + return TRUE; + } + + if (IsA(expr, Expr)) + { + List *temp; + + foreach(temp, ((Expr *) expr)->args) + if (!exprIsAggOrGroupCol(lfirst(temp), groupClause)) + return FALSE; + return TRUE; + } + + return FALSE; +} + +/* + * tleIsAggOrGroupCol - + * returns true if the TargetEntry is Agg or GroupCol. + */ +bool +tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause) +{ + Node *expr = tle->expr; + List *gl; + + if (expr == NULL || IsA(expr, Const) ||IsA(expr, Param)) + return TRUE; + + foreach(gl, groupClause) + { + GroupClause *grpcl = lfirst(gl); + + if (tle->resdom->resno == grpcl->entry->resdom->resno) + { + if (contain_agg_clause((Node *) expr)) + elog(WARN, "parser: aggregates not allowed in GROUP BY clause"); + return TRUE; + } + } + + if (IsA(expr, Aggreg)) + return TRUE; + + if (IsA(expr, Expr)) + { + List *temp; + + foreach(temp, ((Expr *) expr)->args) + if (!exprIsAggOrGroupCol(lfirst(temp), groupClause)) + return FALSE; + return TRUE; + } + + return FALSE; +} + +/* + * parseCheckAggregates - + * this should really be done earlier but the current grammar + * cannot differentiate functions from aggregates. So we have do check + * here when the target list and the qualifications are finalized. + */ +void +parseCheckAggregates(ParseState *pstate, Query *qry) +{ + List *tl; + + Assert(pstate->p_numAgg > 0); + + /* + * aggregates never appear in WHERE clauses. (we have to check where + * clause first because if there is an aggregate, the check for + * non-group column in target list may fail.) + */ + if (contain_agg_clause(qry->qual)) + elog(WARN, "parser: aggregates not allowed in WHERE clause"); + + /* + * the target list can only contain aggregates, group columns and + * functions thereof. + */ + foreach(tl, qry->targetList) + { + TargetEntry *tle = lfirst(tl); + + if (!tleIsAggOrGroupCol(tle, qry->groupClause)) + elog(WARN, + "parser: illegal use of aggregates or non-group column in target list"); + } + + /* + * the expression specified in the HAVING clause has the same + * restriction as those in the target list. + */ +/* + * Need to change here when we get HAVING works. Currently + * qry->havingQual is NULL. - vadim 04/05/97 + if (!exprIsAggOrGroupCol(qry->havingQual, qry->groupClause)) + elog(WARN, + "parser: illegal use of aggregates or non-group column in HAVING clause"); + */ + return; +} + + +Aggreg * +ParseAgg(char *aggname, Oid basetype, Node *target) +{ + Oid fintype; + Oid vartype; + Oid xfn1; + Form_pg_aggregate aggform; + Aggreg *aggreg; + HeapTuple theAggTuple; + + theAggTuple = SearchSysCacheTuple(AGGNAME, PointerGetDatum(aggname), + ObjectIdGetDatum(basetype), + 0, 0); + if (!HeapTupleIsValid(theAggTuple)) + { + elog(WARN, "aggregate %s does not exist", aggname); + } + + aggform = (Form_pg_aggregate) GETSTRUCT(theAggTuple); + fintype = aggform->aggfinaltype; + xfn1 = aggform->aggtransfn1; + + if (nodeTag(target) != T_Var && nodeTag(target) != T_Expr) + elog(WARN, "parser: aggregate can only be applied on an attribute or expression"); + + /* only aggregates with transfn1 need a base type */ + if (OidIsValid(xfn1)) + { + basetype = aggform->aggbasetype; + if (nodeTag(target) == T_Var) + vartype = ((Var *) target)->vartype; + else + vartype = ((Expr *) target)->typeOid; + + if (basetype != vartype) + { + Type tp1, + tp2; + + tp1 = typeidType(basetype); + tp2 = typeidType(vartype); + elog(NOTICE, "Aggregate type mismatch:"); + elog(WARN, "%s works on %s, not %s", aggname, + typeTypeName(tp1), typeTypeName(tp2)); + } + } + + aggreg = makeNode(Aggreg); + aggreg->aggname = pstrdup(aggname); + aggreg->basetype = aggform->aggbasetype; + aggreg->aggtype = fintype; + + aggreg->target = target; + + return aggreg; +} + +/* + * Error message when aggregate lookup fails that gives details of the + * basetype + */ +void +agg_error(char *caller, char *aggname, Oid basetypeID) +{ + + /* + * basetypeID that is Invalid (zero) means aggregate over all types. + * (count) + */ + + if (basetypeID == InvalidOid) + { + elog(WARN, "%s: aggregate '%s' for all types does not exist", caller, aggname); + } + else + { + elog(WARN, "%s: aggregate '%s' for '%s' does not exist", caller, aggname, + typeidTypeName(basetypeID)); + } +} + |