diff options
Diffstat (limited to 'src/pl/plpython/plpython.c')
-rw-r--r-- | src/pl/plpython/plpython.c | 129 |
1 files changed, 120 insertions, 9 deletions
diff --git a/src/pl/plpython/plpython.c b/src/pl/plpython/plpython.c index 4744ee7bebd..4cc0708dcb9 100644 --- a/src/pl/plpython/plpython.c +++ b/src/pl/plpython/plpython.c @@ -268,6 +268,28 @@ typedef struct PLySubtransactionObject bool exited; } PLySubtransactionObject; +/* A list of all known exceptions, generated from backend/utils/errcodes.txt */ +typedef struct ExceptionMap +{ + char *name; + char *classname; + int sqlstate; +} ExceptionMap; + +static const ExceptionMap exception_map[] = { +#include "spiexceptions.h" + {NULL, NULL, 0} +}; + +/* A hash table mapping sqlstates to exceptions, for speedy lookup */ +static HTAB *PLy_spi_exceptions; + +typedef struct PLyExceptionEntry +{ + int sqlstate; /* hash key, must be first */ + PyObject *exc; /* corresponding exception */ +} PLyExceptionEntry; + /* function declarations */ @@ -310,7 +332,7 @@ __attribute__((format(printf, 2, 5))) __attribute__((format(printf, 3, 5))); /* like PLy_exception_set, but conserve more fields from ErrorData */ -static void PLy_spi_exception_set(ErrorData *edata); +static void PLy_spi_exception_set(PyObject *excclass, ErrorData *edata); /* Get the innermost python procedure called from the backend */ static char *PLy_procedure_name(PLyProcedure *); @@ -3013,6 +3035,10 @@ static PyMethodDef PLy_methods[] = { {NULL, NULL, 0, NULL} }; +static PyMethodDef PLy_exc_methods[] = { + {NULL, NULL, 0, NULL} +}; + #if PY_MAJOR_VERSION >= 3 static PyModuleDef PLy_module = { PyModuleDef_HEAD_INIT, /* m_base */ @@ -3021,6 +3047,18 @@ static PyModuleDef PLy_module = { -1, /* m_size */ PLy_methods, /* m_methods */ }; + +static PyModuleDef PLy_exc_module = { + PyModuleDef_HEAD_INIT, /* m_base */ + "spiexceptions", /* m_name */ + NULL, /* m_doc */ + -1, /* m_size */ + PLy_exc_methods, /* m_methods */ + NULL, /* m_reload */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL /* m_free */ +}; #endif /* plan object methods */ @@ -3318,6 +3356,8 @@ PLy_spi_prepare(PyObject *self, PyObject *args) PG_CATCH(); { ErrorData *edata; + PLyExceptionEntry *entry; + PyObject *exc; /* Save error info */ MemoryContextSwitchTo(oldcontext); @@ -3338,8 +3378,14 @@ PLy_spi_prepare(PyObject *self, PyObject *args) */ SPI_restore_connection(); + /* Look up the correct exception */ + entry = hash_search(PLy_spi_exceptions, &(edata->sqlerrcode), + HASH_FIND, NULL); + /* We really should find it, but just in case have a fallback */ + Assert(entry != NULL); + exc = entry ? entry->exc : PLy_exc_spi_error; /* Make Python raise the exception */ - PLy_spi_exception_set(edata); + PLy_spi_exception_set(exc, edata); return NULL; } PG_END_TRY(); @@ -3490,6 +3536,8 @@ PLy_spi_execute_plan(PyObject *ob, PyObject *list, long limit) { int k; ErrorData *edata; + PLyExceptionEntry *entry; + PyObject *exc; /* Save error info */ MemoryContextSwitchTo(oldcontext); @@ -3521,8 +3569,14 @@ PLy_spi_execute_plan(PyObject *ob, PyObject *list, long limit) */ SPI_restore_connection(); + /* Look up the correct exception */ + entry = hash_search(PLy_spi_exceptions, &(edata->sqlerrcode), + HASH_FIND, NULL); + /* We really should find it, but just in case have a fallback */ + Assert(entry != NULL); + exc = entry ? entry->exc : PLy_exc_spi_error; /* Make Python raise the exception */ - PLy_spi_exception_set(edata); + PLy_spi_exception_set(exc, edata); return NULL; } PG_END_TRY(); @@ -3582,7 +3636,9 @@ PLy_spi_execute_query(char *query, long limit) } PG_CATCH(); { - ErrorData *edata; + ErrorData *edata; + PLyExceptionEntry *entry; + PyObject *exc; /* Save error info */ MemoryContextSwitchTo(oldcontext); @@ -3601,8 +3657,14 @@ PLy_spi_execute_query(char *query, long limit) */ SPI_restore_connection(); + /* Look up the correct exception */ + entry = hash_search(PLy_spi_exceptions, &edata->sqlerrcode, + HASH_FIND, NULL); + /* We really should find it, but just in case have a fallback */ + Assert(entry != NULL); + exc = entry ? entry->exc : PLy_exc_spi_error; /* Make Python raise the exception */ - PLy_spi_exception_set(edata); + PLy_spi_exception_set(exc, edata); return NULL; } PG_END_TRY(); @@ -3832,9 +3894,49 @@ PLy_subtransaction_exit(PyObject *self, PyObject *args) /* * Add exceptions to the plpy module */ + +/* + * Add all the autogenerated exceptions as subclasses of SPIError + */ +static void +PLy_generate_spi_exceptions(PyObject *mod, PyObject *base) +{ + int i; + + for (i = 0; exception_map[i].name != NULL; i++) + { + bool found; + PyObject *exc; + PLyExceptionEntry *entry; + PyObject *sqlstate; + PyObject *dict = PyDict_New(); + + sqlstate = PyString_FromString(unpack_sql_state(exception_map[i].sqlstate)); + PyDict_SetItemString(dict, "sqlstate", sqlstate); + Py_DECREF(sqlstate); + exc = PyErr_NewException(exception_map[i].name, base, dict); + PyModule_AddObject(mod, exception_map[i].classname, exc); + entry = hash_search(PLy_spi_exceptions, &exception_map[i].sqlstate, + HASH_ENTER, &found); + entry->exc = exc; + Assert(!found); + } +} + static void PLy_add_exceptions(PyObject *plpy) { + PyObject *excmod; + HASHCTL hash_ctl; + +#if PY_MAJOR_VERSION < 3 + excmod = Py_InitModule("spiexceptions", PLy_exc_methods); +#else + excmod = PyModule_Create(&PLy_exc_module); +#endif + if (PyModule_AddObject(plpy, "spiexceptions", excmod) < 0) + PLy_elog(ERROR, "failed to add the spiexceptions module"); + PLy_exc_error = PyErr_NewException("plpy.Error", NULL, NULL); PLy_exc_fatal = PyErr_NewException("plpy.Fatal", NULL, NULL); PLy_exc_spi_error = PyErr_NewException("plpy.SPIError", NULL, NULL); @@ -3845,6 +3947,15 @@ PLy_add_exceptions(PyObject *plpy) PyModule_AddObject(plpy, "Fatal", PLy_exc_fatal); Py_INCREF(PLy_exc_spi_error); PyModule_AddObject(plpy, "SPIError", PLy_exc_spi_error); + + memset(&hash_ctl, 0, sizeof(hash_ctl)); + hash_ctl.keysize = sizeof(int); + hash_ctl.entrysize = sizeof(PLyExceptionEntry); + hash_ctl.hash = tag_hash; + PLy_spi_exceptions = hash_create("SPI exceptions", 256, + &hash_ctl, HASH_ELEM | HASH_FUNCTION); + + PLy_generate_spi_exceptions(excmod, PLy_exc_spi_error); } #if PY_MAJOR_VERSION >= 3 @@ -4205,7 +4316,7 @@ PLy_exception_set_plural(PyObject *exc, * internal query and error position. */ static void -PLy_spi_exception_set(ErrorData *edata) +PLy_spi_exception_set(PyObject *excclass, ErrorData *edata) { PyObject *args = NULL; PyObject *spierror = NULL; @@ -4215,8 +4326,8 @@ PLy_spi_exception_set(ErrorData *edata) if (!args) goto failure; - /* create a new SPIError with the error message as the parameter */ - spierror = PyObject_CallObject(PLy_exc_spi_error, args); + /* create a new SPI exception with the error message as the parameter */ + spierror = PyObject_CallObject(excclass, args); if (!spierror) goto failure; @@ -4228,7 +4339,7 @@ PLy_spi_exception_set(ErrorData *edata) if (PyObject_SetAttrString(spierror, "spidata", spidata) == -1) goto failure; - PyErr_SetObject(PLy_exc_spi_error, spierror); + PyErr_SetObject(excclass, spierror); Py_DECREF(args); Py_DECREF(spierror); |