diff options
Diffstat (limited to 'src/backend/commands/aggregatecmds.c')
-rw-r--r-- | src/backend/commands/aggregatecmds.c | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/src/backend/commands/aggregatecmds.c b/src/backend/commands/aggregatecmds.c index 59bc6e6fd8f..3424f842b9c 100644 --- a/src/backend/commands/aggregatecmds.c +++ b/src/backend/commands/aggregatecmds.c @@ -62,6 +62,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *transfuncName = NIL; List *finalfuncName = NIL; List *combinefuncName = NIL; + List *serialfuncName = NIL; + List *deserialfuncName = NIL; List *mtransfuncName = NIL; List *minvtransfuncName = NIL; List *mfinalfuncName = NIL; @@ -70,6 +72,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *sortoperatorName = NIL; TypeName *baseType = NULL; TypeName *transType = NULL; + TypeName *serialType = NULL; TypeName *mtransType = NULL; int32 transSpace = 0; int32 mtransSpace = 0; @@ -84,6 +87,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, List *parameterDefaults; Oid variadicArgType; Oid transTypeId; + Oid serialTypeId = InvalidOid; Oid mtransTypeId = InvalidOid; char transTypeType; char mtransTypeType = 0; @@ -127,6 +131,10 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, finalfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "combinefunc") == 0) combinefuncName = defGetQualifiedName(defel); + else if (pg_strcasecmp(defel->defname, "serialfunc") == 0) + serialfuncName = defGetQualifiedName(defel); + else if (pg_strcasecmp(defel->defname, "deserialfunc") == 0) + deserialfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "msfunc") == 0) mtransfuncName = defGetQualifiedName(defel); else if (pg_strcasecmp(defel->defname, "minvfunc") == 0) @@ -154,6 +162,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, } else if (pg_strcasecmp(defel->defname, "stype") == 0) transType = defGetTypeName(defel); + else if (pg_strcasecmp(defel->defname, "serialtype") == 0) + serialType = defGetTypeName(defel); else if (pg_strcasecmp(defel->defname, "stype1") == 0) transType = defGetTypeName(defel); else if (pg_strcasecmp(defel->defname, "sspace") == 0) @@ -319,6 +329,75 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, format_type_be(transTypeId)))); } + if (serialType) + { + /* + * There's little point in having a serialization/deserialization + * function on aggregates that don't have an internal state, so let's + * just disallow this as it may help clear up any confusion or needless + * authoring of these functions. + */ + if (transTypeId != INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("a serialization type must only be specified when the aggregate transition data type is \"%s\"", + format_type_be(INTERNALOID)))); + + serialTypeId = typenameTypeId(NULL, serialType); + + if (get_typtype(mtransTypeId) == TYPTYPE_PSEUDO && + !IsPolymorphicType(serialTypeId)) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate serialization data type cannot be %s", + format_type_be(serialTypeId)))); + + /* + * We disallow INTERNAL serialType as the whole point of the + * serialized types is to allow the aggregate state to be output, + * and we cannot output INTERNAL. This check, combined with the one + * above ensures that the trans type and serialization type are not the + * same. + */ + if (serialTypeId == INTERNALOID) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate serialization type cannot be \"%s\"", + format_type_be(serialTypeId)))); + + /* + * If serialType is specified then serialfuncName and deserialfuncName + * must be present; if not, then none of the serialization options + * should have been specified. + */ + if (serialfuncName == NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate serialization function must be specified when serialization type is specified"))); + + if (deserialfuncName == NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("aggregate deserialization function must be specified when serialization type is specified"))); + } + else + { + /* + * If serialization type was not specified then there shouldn't be a + * serialization function. + */ + if (serialfuncName != NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("must specify serialization type when specifying serialization function"))); + + /* likewise for the deserialization function */ + if (deserialfuncName != NIL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("must specify serialization type when specifying deserialization function"))); + } + /* * If a moving-aggregate transtype is specified, look that up. Same * restrictions as for transtype. @@ -387,6 +466,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, transfuncName, /* step function name */ finalfuncName, /* final function name */ combinefuncName, /* combine function name */ + serialfuncName, /* serial function name */ + deserialfuncName, /* deserial function name */ mtransfuncName, /* fwd trans function name */ minvtransfuncName, /* inv trans function name */ mfinalfuncName, /* final function name */ @@ -394,6 +475,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, mfinalfuncExtraArgs, sortoperatorName, /* sort operator name */ transTypeId, /* transition data type */ + serialTypeId, /* serialization data type */ transSpace, /* transition space */ mtransTypeId, /* transition data type */ mtransSpace, /* transition space */ |