From c2f25b51798c15b59300c0881c06262cab0c452d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sun, 28 Dec 2025 12:49:29 +0100 Subject: [PATCH 1/4] gh-142830: convert callback context into a full-fledged `PyObject *` --- Makefile.pre.in | 2 +- Modules/Setup.stdlib.in | 2 +- Modules/_sqlite/clinic/context.c.h | 30 +++++ Modules/_sqlite/connection.c | 192 ++++++++++------------------- Modules/_sqlite/connection.h | 15 +-- Modules/_sqlite/context.c | 121 ++++++++++++++++++ Modules/_sqlite/context.h | 37 ++++++ Modules/_sqlite/module.c | 1 + Modules/_sqlite/module.h | 1 + PCbuild/_sqlite3.vcxproj | 4 +- PCbuild/_sqlite3.vcxproj.filters | 8 +- 11 files changed, 275 insertions(+), 138 deletions(-) create mode 100644 Modules/_sqlite/clinic/context.c.h create mode 100644 Modules/_sqlite/context.c create mode 100644 Modules/_sqlite/context.h diff --git a/Makefile.pre.in b/Makefile.pre.in index a6beb96d12a3f2..36c81b0337fddd 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -3429,7 +3429,7 @@ MODULE__SSL_DEPS=$(srcdir)/Modules/_ssl.h $(srcdir)/Modules/_ssl/cert.c $(srcdir MODULE__TESTCAPI_DEPS=$(srcdir)/Modules/_testcapi/parts.h $(srcdir)/Modules/_testcapi/util.h MODULE__TESTLIMITEDCAPI_DEPS=$(srcdir)/Modules/_testlimitedcapi/testcapi_long.h $(srcdir)/Modules/_testlimitedcapi/parts.h $(srcdir)/Modules/_testlimitedcapi/util.h MODULE__TESTINTERNALCAPI_DEPS=$(srcdir)/Modules/_testinternalcapi/parts.h -MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h +MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/context.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h MODULE__ZSTD_DEPS=$(srcdir)/Modules/_zstd/_zstdmodule.h $(srcdir)/Modules/_zstd/buffer.h $(srcdir)/Modules/_zstd/zstddict.h CODECS_COMMON_HEADERS=$(srcdir)/Modules/cjkcodecs/multibytecodec.h $(srcdir)/Modules/cjkcodecs/cjkcodecs.h diff --git a/Modules/Setup.stdlib.in b/Modules/Setup.stdlib.in index 2a4b937ce6bf80..59d649825fbe9d 100644 --- a/Modules/Setup.stdlib.in +++ b/Modules/Setup.stdlib.in @@ -149,7 +149,7 @@ # needs -lncurses[w] and -lpanel[w] @MODULE__CURSES_PANEL_TRUE@_curses_panel _curses_panel.c -@MODULE__SQLITE3_TRUE@_sqlite3 _sqlite/blob.c _sqlite/connection.c _sqlite/cursor.c _sqlite/microprotocols.c _sqlite/module.c _sqlite/prepare_protocol.c _sqlite/row.c _sqlite/statement.c _sqlite/util.c +@MODULE__SQLITE3_TRUE@_sqlite3 _sqlite/blob.c _sqlite/connection.c _sqlite/context.c _sqlite/cursor.c _sqlite/microprotocols.c _sqlite/module.c _sqlite/prepare_protocol.c _sqlite/row.c _sqlite/statement.c _sqlite/util.c # needs -lssl and -lcrypt @MODULE__SSL_TRUE@_ssl _ssl.c diff --git a/Modules/_sqlite/clinic/context.c.h b/Modules/_sqlite/clinic/context.c.h new file mode 100644 index 00000000000000..9788f819ac7f68 --- /dev/null +++ b/Modules/_sqlite/clinic/context.c.h @@ -0,0 +1,30 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#include "pycore_modsupport.h" // _PyArg_CheckPositional() + +static PyObject * +callback_context_new_impl(PyTypeObject *type, PyObject *callable); + +static PyObject * +callback_context_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) +{ + PyObject *return_value = NULL; + PyTypeObject *base_tp = clinic_state()->CallbackContextType; + PyObject *callable; + + if ((type == base_tp || type->tp_init == base_tp->tp_init) && + !_PyArg_NoKeywords("_CallbackContext", kwargs)) { + goto exit; + } + if (!_PyArg_CheckPositional("_CallbackContext", PyTuple_GET_SIZE(args), 1, 1)) { + goto exit; + } + callable = PyTuple_GET_ITEM(args, 0); + return_value = callback_context_new_impl(type, callable); + +exit: + return return_value; +} +/*[clinic end generated code: output=370246c27daeaaa1 input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 83ff8e60557c07..53d1953b55e5b0 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -145,9 +145,6 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp /*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/ static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self); -static void free_callback_context(callback_context *ctx); -static void set_callback_context(callback_context **ctx_pp, - callback_context *ctx); static int connection_close(pysqlite_Connection *self); PyObject *_pysqlite_query_execute(pysqlite_Cursor *, int, PyObject *, PyObject *); @@ -377,13 +374,13 @@ output pop [clinic start generated code]*/ /*[clinic end generated code: output=da39a3ee5e6b4b0d input=b899ba9273edcce7]*/ -#define VISIT_CALLBACK_CONTEXT(ctx) \ -do { \ - if (ctx) { \ - Py_VISIT(ctx->callable); \ - Py_VISIT(ctx->module); \ - } \ -} while (0) +static void +clear_callback_contexts(pysqlite_Connection *self) +{ + Py_CLEAR(self->trace_ctx); + Py_CLEAR(self->progress_ctx); + Py_CLEAR(self->authorizer_ctx); +} static int connection_traverse(PyObject *op, visitproc visit, void *arg) @@ -395,22 +392,12 @@ connection_traverse(PyObject *op, visitproc visit, void *arg) Py_VISIT(self->blobs); Py_VISIT(self->row_factory); Py_VISIT(self->text_factory); - VISIT_CALLBACK_CONTEXT(self->trace_ctx); - VISIT_CALLBACK_CONTEXT(self->progress_ctx); - VISIT_CALLBACK_CONTEXT(self->authorizer_ctx); -#undef VISIT_CALLBACK_CONTEXT + Py_VISIT(self->trace_ctx); + Py_VISIT(self->progress_ctx); + Py_VISIT(self->authorizer_ctx); return 0; } -static inline void -clear_callback_context(callback_context *ctx) -{ - if (ctx != NULL) { - Py_CLEAR(ctx->callable); - Py_CLEAR(ctx->module); - } -} - static int connection_clear(PyObject *op) { @@ -420,20 +407,10 @@ connection_clear(PyObject *op) Py_CLEAR(self->blobs); Py_CLEAR(self->row_factory); Py_CLEAR(self->text_factory); - clear_callback_context(self->trace_ctx); - clear_callback_context(self->progress_ctx); - clear_callback_context(self->authorizer_ctx); + clear_callback_contexts(self); return 0; } -static void -free_callback_contexts(pysqlite_Connection *self) -{ - set_callback_context(&self->trace_ctx, NULL); - set_callback_context(&self->progress_ctx, NULL); - set_callback_context(&self->authorizer_ctx, NULL); -} - static void remove_callbacks(sqlite3 *db) { @@ -474,7 +451,7 @@ connection_close(pysqlite_Connection *self) (void)sqlite3_close_v2(db); Py_END_ALLOW_THREADS - free_callback_contexts(self); + clear_callback_contexts(self); return rc; } @@ -814,7 +791,7 @@ _pysqlite_set_result(sqlite3_context* context, PyObject* py_val) sqlite3_result_blob(context, view.buf, (int)view.len, SQLITE_TRANSIENT); PyBuffer_Release(&view); } else { - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); PyErr_Format(ctx->state->ProgrammingError, "User-defined functions cannot return '%s' values to " "SQLite", @@ -893,7 +870,7 @@ _pysqlite_build_py_params(sqlite3_context *context, int argc, } static void -print_or_clear_traceback(callback_context *ctx) +print_or_clear_traceback(pysqlite_CallbackContext *ctx) { assert(ctx != NULL); assert(ctx->state != NULL); @@ -920,7 +897,7 @@ set_sqlite_error(sqlite3_context *context, const char *msg) else { sqlite3_result_error(context, msg, -1); } - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); print_or_clear_traceback(ctx); } @@ -935,7 +912,7 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) args = _pysqlite_build_py_params(context, argc, argv); if (args) { - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); py_retval = PyObject_CallObject(ctx->callable, args); Py_DECREF(args); @@ -963,7 +940,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) PyObject** aggregate_instance; PyObject* stepmethod = NULL; - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); @@ -1032,7 +1009,7 @@ final_callback(sqlite3_context *context) // Keep the exception (if any) of the last call to step, value, or inverse PyObject *exc = PyErr_GetRaisedException(); - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); function_result = PyObject_CallMethodNoArgs(*aggregate_instance, ctx->state->str_finalize); @@ -1095,55 +1072,14 @@ _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self) return 0; } -/* Allocate a UDF/callback context structure. In order to ensure that the state - * pointer always outlives the callback context, we make sure it owns a - * reference to the module itself. create_callback_context() is always called - * from connection methods, so we use the defining class to fetch the module - * pointer. - */ -static callback_context * -create_callback_context(PyTypeObject *cls, PyObject *callable) -{ - callback_context *ctx = PyMem_Malloc(sizeof(callback_context)); - if (ctx != NULL) { - PyObject *module = PyType_GetModule(cls); - ctx->callable = Py_NewRef(callable); - ctx->module = Py_NewRef(module); - ctx->state = pysqlite_get_state(module); - } - return ctx; -} - -static void -free_callback_context(callback_context *ctx) -{ - assert(ctx != NULL); - Py_XDECREF(ctx->callable); - Py_XDECREF(ctx->module); - PyMem_Free(ctx); -} - -static void -set_callback_context(callback_context **ctx_pp, callback_context *ctx) -{ - assert(ctx_pp != NULL); - callback_context *tmp = *ctx_pp; - *ctx_pp = ctx; - if (tmp != NULL) { - free_callback_context(tmp); - } -} - static void destructor_callback(void *ctx) { - if (ctx != NULL) { - // This function may be called without the GIL held, so we need to - // ensure that we destroy 'ctx' with the GIL held. - PyGILState_STATE gstate = PyGILState_Ensure(); - free_callback_context((callback_context *)ctx); - PyGILState_Release(gstate); - } + // This function may be called without the GIL held, so we need to + // ensure that we destroy 'ctx' with the GIL held. + PyGILState_STATE gstate = PyGILState_Ensure(); + Py_DECREF(ctx); + PyGILState_Release(gstate); } static int @@ -1194,7 +1130,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, if (deterministic) { flags |= SQLITE_DETERMINISTIC; } - callback_context *ctx = create_callback_context(cls, func); + PyObject *ctx = pysqlite_create_callback_context(self->state, func); if (ctx == NULL) { return NULL; } @@ -1225,7 +1161,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) { PyGILState_STATE gilstate = PyGILState_Ensure(); - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); int size = sizeof(PyObject *); @@ -1273,7 +1209,7 @@ value_callback(sqlite3_context *context) { PyGILState_STATE gilstate = PyGILState_Ensure(); - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); int size = sizeof(PyObject *); @@ -1344,7 +1280,7 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, 0, 0, 0, 0, 0, 0); } else { - callback_context *ctx = create_callback_context(cls, aggregate_class); + PyObject *ctx = pysqlite_create_callback_context(self->state, aggregate_class); if (ctx == NULL) { return NULL; } @@ -1395,7 +1331,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, return NULL; } - callback_context *ctx = create_callback_context(cls, aggregate_class); + PyObject *ctx = pysqlite_create_callback_context(self->state, aggregate_class); if (ctx == NULL) { return NULL; } @@ -1413,7 +1349,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, } static int -authorizer_callback(void *ctx, int action, const char *arg1, +authorizer_callback(void *ctx_vp, int action, const char *arg1, const char *arg2 , const char *dbname, const char *access_attempt_source) { @@ -1422,8 +1358,9 @@ authorizer_callback(void *ctx, int action, const char *arg1, PyObject *ret; int rc = SQLITE_DENY; - assert(ctx != NULL); - PyObject *callable = ((callback_context *)ctx)->callable; + assert(ctx_vp != NULL); + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + PyObject *callable = ctx->callable; ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname, access_attempt_source); @@ -1450,15 +1387,16 @@ authorizer_callback(void *ctx, int action, const char *arg1, } static int -progress_callback(void *ctx) +progress_callback(void *ctx_vp) { PyGILState_STATE gilstate = PyGILState_Ensure(); int rc; PyObject *ret; - assert(ctx != NULL); - PyObject *callable = ((callback_context *)ctx)->callable; + assert(ctx_vp != NULL); + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + PyObject *callable = ctx->callable; ret = PyObject_CallNoArgs(callable); if (!ret) { /* abort query if error occurred */ @@ -1483,7 +1421,7 @@ progress_callback(void *ctx) * to ensure future compatibility. */ static int -trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) +trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) { if (type != SQLITE_TRACE_STMT) { return 0; @@ -1491,8 +1429,9 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) PyGILState_STATE gilstate = PyGILState_Ensure(); - assert(ctx != NULL); - pysqlite_state *state = ((callback_context *)ctx)->state; + assert(ctx_vp != NULL); + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + pysqlite_state *state = ctx->state; assert(state != NULL); PyObject *py_statement = NULL; @@ -1506,7 +1445,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) PyErr_SetString(state->DataError, "Expanded SQL string exceeds the maximum string length"); - print_or_clear_traceback((callback_context *)ctx); + print_or_clear_traceback(ctx); // Fall back to unexpanded sql py_statement = PyUnicode_FromString((const char *)sql); @@ -1516,13 +1455,12 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) sqlite3_free((void *)expanded_sql); } if (py_statement) { - PyObject *callable = ((callback_context *)ctx)->callable; - PyObject *ret = PyObject_CallOneArg(callable, py_statement); + PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement); Py_DECREF(py_statement); Py_XDECREF(ret); } if (PyErr_Occurred()) { - print_or_clear_traceback((callback_context *)ctx); + print_or_clear_traceback(ctx); } exit: @@ -1552,21 +1490,22 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self, int rc; if (callable == Py_None) { + /* None clears the authorizer callback previously set */ rc = sqlite3_set_authorizer(self->db, NULL, NULL); - set_callback_context(&self->authorizer_ctx, NULL); + Py_CLEAR(self->authorizer_ctx); } else { - callback_context *ctx = create_callback_context(cls, callable); + PyObject *ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } rc = sqlite3_set_authorizer(self->db, authorizer_callback, ctx); - set_callback_context(&self->authorizer_ctx, ctx); + Py_XSETREF(self->authorizer_ctx, ctx); } if (rc != SQLITE_OK) { PyErr_SetString(self->OperationalError, "Error setting authorizer callback"); - set_callback_context(&self->authorizer_ctx, NULL); + Py_CLEAR(self->authorizer_ctx); return NULL; } Py_RETURN_NONE; @@ -1604,15 +1543,15 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self, if (callable == Py_None) { /* None clears the progress handler previously set */ sqlite3_progress_handler(self->db, 0, 0, (void*)0); - set_callback_context(&self->progress_ctx, NULL); + Py_CLEAR(self->progress_ctx); } else { - callback_context *ctx = create_callback_context(cls, callable); + PyObject *ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } sqlite3_progress_handler(self->db, n, progress_callback, ctx); - set_callback_context(&self->progress_ctx, ctx); + Py_XSETREF(self->progress_ctx, ctx); } Py_RETURN_NONE; } @@ -1637,6 +1576,7 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self, return NULL; } + int rc; if (callable == Py_None) { /* * None clears the trace callback previously set @@ -1645,18 +1585,22 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self, * - https://sqlite.org/c3ref/c_trace.html * - https://sqlite.org/c3ref/trace_v2.html */ - sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, 0, 0); - set_callback_context(&self->trace_ctx, NULL); + rc = sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, 0, 0); + Py_CLEAR(self->trace_ctx); } else { - callback_context *ctx = create_callback_context(cls, callable); + PyObject *ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } - sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx); - set_callback_context(&self->trace_ctx, ctx); + rc = sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx); + Py_XSETREF(self->trace_ctx, ctx); + } + if (rc != SQLITE_OK) { + PyErr_SetString(self->OperationalError, "Error setting trace callback"); + Py_CLEAR(self->trace_ctx); + return NULL; } - Py_RETURN_NONE; } @@ -1966,7 +1910,7 @@ collation_callback(void *context, int text1_length, const void *text1_data, goto finally; } - callback_context *ctx = (callback_context *)context; + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(context); assert(ctx != NULL); PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs. size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET; @@ -2185,7 +2129,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, return NULL; } - callback_context *ctx = NULL; + PyObject *ctx = NULL; int rc; int flags = SQLITE_UTF8; if (callable == Py_None) { @@ -2193,11 +2137,13 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, NULL, NULL, NULL); } else { + // TODO(picnixz): defer this check to the context's constructor + // and do it for all other functions that create a context. if (!PyCallable_Check(callable)) { PyErr_SetString(PyExc_TypeError, "parameter must be callable"); return NULL; } - ctx = create_callback_context(cls, callable); + ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } @@ -2211,9 +2157,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, * called if sqlite3_create_collation_v2() fails, so we have to free * the context before returning. */ - if (callable != Py_None) { - free_callback_context(ctx); - } + Py_XDECREF(ctx); set_error_from_db(self->state, self->db); return NULL; } diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 7a748ee3ea0c58..20c20d772befa3 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -23,21 +23,16 @@ #ifndef PYSQLITE_CONNECTION_H #define PYSQLITE_CONNECTION_H + #include "Python.h" #include "pythread.h" #include "structmember.h" +#include "context.h" #include "module.h" #include "sqlite3.h" -typedef struct _callback_context -{ - PyObject *callable; - PyObject *module; - pysqlite_state *state; -} callback_context; - enum autocommit_mode { AUTOCOMMIT_LEGACY = LEGACY_TRANSACTION_CONTROL, AUTOCOMMIT_ENABLED = 1, @@ -88,9 +83,9 @@ typedef struct PyObject* text_factory; // Remember contexts used by the trace, progress, and authoriser callbacks - callback_context *trace_ctx; - callback_context *progress_ctx; - callback_context *authorizer_ctx; + PyObject *trace_ctx; + PyObject *progress_ctx; + PyObject *authorizer_ctx; /* Exception objects: borrowed refs. */ PyObject* Warning; diff --git a/Modules/_sqlite/context.c b/Modules/_sqlite/context.c new file mode 100644 index 00000000000000..01077eb33215aa --- /dev/null +++ b/Modules/_sqlite/context.c @@ -0,0 +1,121 @@ +#ifndef Py_BUILD_CORE_BUILTIN +# define Py_BUILD_CORE_MODULE 1 +#endif + +#include "context.h" +#include "module.h" + +#define clinic_state() (pysqlite_get_state_by_type(type)) +#include "clinic/context.c.h" +#undef clinic_state + +/*[clinic input] +module _sqlite3 +class _sqlite3._CallbackContext "pysqlite_CallbackContext *" "clinic_state()->CallbackContextType" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=f06c18b7c3666e05]*/ + +/*[clinic input] +@classmethod +_sqlite3._CallbackContext.__new__ as callback_context_new + + callable: object + / + +[clinic start generated code]*/ + +static PyObject * +callback_context_new_impl(PyTypeObject *type, PyObject *callable) +/*[clinic end generated code: output=46b4f355475d88cc input=ae33656b48f65a6d]*/ +{ + PyObject *module = PyType_GetModuleByDef(type, &_sqlite3module); + assert(module != NULL); + pysqlite_state *state = pysqlite_get_state_by_type(type); + assert(state != NULL); + + pysqlite_CallbackContext *ctx = PyObject_GC_New( + pysqlite_CallbackContext, + state->CallbackContextType + ); + if (ctx == NULL) { + return NULL; + } + + // TODO(picnixz): check that 'callable' is effectively callable + // instead of relying on a generic TypeError when attempting to + // call it. + ctx->callable = Py_NewRef(callable); + ctx->module = Py_NewRef(module); + ctx->state = state; + PyObject_GC_Track(ctx); + return (PyObject *)ctx; + +} + +static int +callback_context_clear(PyObject *op) +{ + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(op); + Py_CLEAR(ctx->callable); + Py_CLEAR(ctx->module); + return 0; +} + +static void +callback_context_dealloc(PyObject *op) +{ + PyTypeObject *type = Py_TYPE(op); + PyObject_GC_UnTrack(op); + (void)type->tp_clear(op); + type->tp_free(op); + Py_DECREF(type); +} + +static int +callback_context_traverse(PyObject *op, visitproc visit, void *arg) +{ + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(op); + Py_VISIT(Py_TYPE(op)); + Py_VISIT(ctx->callable); + Py_VISIT(ctx->module); + return 0; +} + +static PyType_Slot callback_context_slots[] = { + {Py_tp_new, callback_context_new}, + {Py_tp_clear, callback_context_clear}, + {Py_tp_dealloc, callback_context_dealloc}, + {Py_tp_traverse, callback_context_traverse}, + {0, NULL}, +}; + +static PyType_Spec callback_context_spec = { + .name = MODULE_NAME "._CallbackContext", + .basicsize = sizeof(pysqlite_CallbackContext), + .flags = ( + Py_TPFLAGS_DEFAULT + | Py_TPFLAGS_DISALLOW_INSTANTIATION + | Py_TPFLAGS_IMMUTABLETYPE + | Py_TPFLAGS_HAVE_GC + ), + .slots = callback_context_slots, +}; + +PyObject * +pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable) +{ + PyTypeObject *type = state->CallbackContextType; + return callback_context_new_impl(type, callable); +} + +int +pysqlite_context_setup_types(PyObject *module) +{ + PyObject *type = PyType_FromModuleAndSpec(module, &callback_context_spec, NULL); + if (type == NULL) { + return -1; + } + pysqlite_state *state = pysqlite_get_state(module); + state->CallbackContextType = (PyTypeObject *)type; + return 0; +} diff --git a/Modules/_sqlite/context.h b/Modules/_sqlite/context.h new file mode 100644 index 00000000000000..10f893ad50e382 --- /dev/null +++ b/Modules/_sqlite/context.h @@ -0,0 +1,37 @@ +#ifndef PYSQLITE_CALLBACK_CONTEXT_H +#define PYSQLITE_CALLBACK_CONTEXT_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "Python.h" + +#include "module.h" + +typedef struct { + PyObject_HEAD + PyObject *callable; + PyObject *module; + pysqlite_state *state; +} pysqlite_CallbackContext; + +#define pysqlite_CallbackContext_CAST(op) ((pysqlite_CallbackContext *)(op)) + +/* Allocate a UDF/callback context structure. In order to ensure that the state + * pointer always outlives the callback context, we make sure it owns a + * reference to the module itself. create_callback_context() is always called + * from connection methods, so we use the defining class to fetch the module + * pointer. + */ +PyObject * +pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable); + +int +pysqlite_context_setup_types(PyObject *module); + +#ifdef __cplusplus +} +#endif + +#endif // !PYSQLITE_CALLBACK_CONTEXT_H diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 831dd9219f77ab..923268159a4197 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -686,6 +686,7 @@ module_exec(PyObject *module) if ((pysqlite_row_setup_types(module) < 0) || (pysqlite_cursor_setup_types(module) < 0) || + (pysqlite_context_setup_types(module) < 0) || (pysqlite_connection_setup_types(module) < 0) || (pysqlite_statement_setup_types(module) < 0) || (pysqlite_prepare_protocol_setup_types(module) < 0) || diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index a4ca45cf6326a9..4fee440146c7f6 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -55,6 +55,7 @@ typedef struct { int enable_callback_tracebacks; PyTypeObject *BlobType; + PyTypeObject *CallbackContextType; PyTypeObject *ConnectionType; PyTypeObject *CursorType; PyTypeObject *PrepareProtocolType; diff --git a/PCbuild/_sqlite3.vcxproj b/PCbuild/_sqlite3.vcxproj index 9ae0a0fc3a009d..4919b28293875e 100644 --- a/PCbuild/_sqlite3.vcxproj +++ b/PCbuild/_sqlite3.vcxproj @@ -99,6 +99,7 @@ + @@ -110,6 +111,7 @@ + @@ -135,4 +137,4 @@ - \ No newline at end of file + diff --git a/PCbuild/_sqlite3.vcxproj.filters b/PCbuild/_sqlite3.vcxproj.filters index f4a265eba7dd80..6a930a01d2e519 100644 --- a/PCbuild/_sqlite3.vcxproj.filters +++ b/PCbuild/_sqlite3.vcxproj.filters @@ -15,6 +15,9 @@ Header Files + + Header Files + Header Files @@ -44,6 +47,9 @@ Source Files + + Source Files + Source Files @@ -74,4 +80,4 @@ Resource Files - \ No newline at end of file + From 4dd06525f881d8a19a5b1b6a1334db1d60d2dd32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sun, 28 Dec 2025 14:06:27 +0100 Subject: [PATCH 2/4] gh-142830: prevent crashes when replacing sqlite3 callbacks --- Lib/test/test_sqlite3/test_dbapi.py | 99 +++++++++++++++++++ ...-12-28-13-12-40.gh-issue-142830.uEyd6r.rst | 2 + Modules/_sqlite/connection.c | 41 ++++++-- 3 files changed, 135 insertions(+), 7 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 20e39f61e4dedb..5e07f42f898c79 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -2029,5 +2029,104 @@ def test_row_is_a_sequence(self): self.assertIsInstance(row, Sequence) +class CallbackTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.cx = sqlite.connect(":memory:") + self.addCleanup(self.cx.close) + self.cu = self.cx.cursor() + self.cu.execute("create table test(a number)") + + class Handler: + cx = self.cx + + self.handler_class = Handler + + def assert_not_authorized(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"): + func(*args, **kwargs) + + def assert_interrupted(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"): + func(*args, **kwargs) + + def assert_invalid_trace(self, func, /, *args, **kwargs): + # Exception in trace callbacks are entirely suppressed. + pass + + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_authorizer_invalid_signature(self): + self.cx.set_authorizer(lambda: None) + self.assert_not_authorized(self.cx.execute, "select * from test") + + def test_progress_handler_invalid_signature(self): + self.cx.set_progress_handler(lambda x: None, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + def test_trace_callback_invalid_signature_traceback(self): + self.cx.set_trace_callback(lambda: None) + self.assert_invalid_trace(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + def test_authorizer_concurrent_mutation_in_call(self): + class Handler(self.handler_class): + def __call__(self, *a, **kw): + self.cx.set_authorizer(None) + raise ValueError + + self.cx.set_authorizer(Handler()) + self.assert_not_authorized(self.cx.execute, "select * from test") + + def test_authorizer_concurrent_mutation_with_overflown_value(self): + _testcapi = import_helper.import_module("_testcapi") + + class Handler(self.handler_class): + def __call__(self, *a, **kw): + self.cx.set_authorizer(None) + # We expect 'int' at the C level, so this one will raise + # when converting via PyLong_Int(). + return _testcapi.INT_MAX + 1 + + self.cx.set_authorizer(Handler()) + self.assert_not_authorized(self.cx.execute, "select * from test") + + def test_progress_handler_concurrent_mutation_in_call(self): + class Handler(self.handler_class): + def __call__(self, *a, **kw): + self.cx.set_authorizer(None) + raise ValueError + + self.cx.set_progress_handler(Handler(), 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + def test_progress_handler_concurrent_mutation_in_conversion(self): + class Handler(self.handler_class): + def __bool__(self): + # clear the progress handler + self.cx.set_progress_handler(None, 1) + raise ValueError # force PyObject_True() to fail + + self.cx.set_progress_handler(Handler.__init__, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + def test_trace_callback_concurrent_mutation_in_call(self): + class Handler: + def __call__(self, statement): + # clear the progress handler + self.cx.set_progress_handler(None, 1) + raise ValueError + + self.cx.set_trace_callback(Handler()) + self.assert_invalid_trace(self.cx.execute, "select * from test") + + # TODO(picnixz): increase test coverage for other callbacks + # such as 'func', 'step', 'finalize', and 'collation'. + + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst new file mode 100644 index 00000000000000..246979e91d76b5 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst @@ -0,0 +1,2 @@ +:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks +are mutated during a callback execution. Patch by Bénédikt Tran. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 53d1953b55e5b0..26e2cddbef9d6b 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -914,7 +914,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) if (args) { pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + Py_INCREF(ctx); py_retval = PyObject_CallObject(ctx->callable, args); + Py_DECREF(ctx); Py_DECREF(args); } @@ -942,6 +944,8 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); if (aggregate_instance == NULL) { @@ -971,6 +975,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) } function_result = PyObject_CallObject(stepmethod, args); + Py_CLEAR(ctx); Py_DECREF(args); if (!function_result) { @@ -979,6 +984,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) } error: + Py_XDECREF(ctx); Py_XDECREF(stepmethod); Py_XDECREF(function_result); @@ -1011,8 +1017,10 @@ final_callback(sqlite3_context *context) pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + Py_INCREF(ctx); function_result = PyObject_CallMethodNoArgs(*aggregate_instance, ctx->state->str_finalize); + Py_DECREF(ctx); Py_DECREF(*aggregate_instance); ok = 0; @@ -1163,6 +1171,8 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); int size = sizeof(PyObject *); PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); @@ -1191,9 +1201,11 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) "user-defined aggregate's 'inverse' method raised error"); goto exit; } + Py_CLEAR(ctx); Py_DECREF(res); exit: + Py_XDECREF(ctx); Py_XDECREF(method); PyGILState_Release(gilstate); } @@ -1217,7 +1229,10 @@ value_callback(sqlite3_context *context) assert(cls != NULL); assert(*cls != NULL); + Py_INCREF(ctx); PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value); + Py_DECREF(ctx); + if (res == NULL) { int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); set_sqlite_error(context, attr_err @@ -1360,10 +1375,11 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1, assert(ctx_vp != NULL); pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); - PyObject *callable = ctx->callable; - ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname, - access_attempt_source); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + ret = PyObject_CallFunction(ctx->callable, "issss", action, arg1, arg2, + dbname, access_attempt_source); if (ret == NULL) { print_or_clear_traceback(ctx); rc = SQLITE_DENY; @@ -1381,6 +1397,7 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1, } Py_DECREF(ret); } + Py_DECREF(ctx); PyGILState_Release(gilstate); return rc; @@ -1396,8 +1413,10 @@ progress_callback(void *ctx_vp) assert(ctx_vp != NULL); pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); - PyObject *callable = ctx->callable; - ret = PyObject_CallNoArgs(callable); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + + ret = PyObject_CallNoArgs(ctx->callable); if (!ret) { /* abort query if error occurred */ rc = -1; @@ -1409,7 +1428,7 @@ progress_callback(void *ctx_vp) if (rc < 0) { print_or_clear_traceback(ctx); } - + Py_DECREF(ctx); PyGILState_Release(gilstate); return rc; } @@ -1455,7 +1474,9 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) sqlite3_free((void *)expanded_sql); } if (py_statement) { + Py_INCREF(ctx); PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement); + Py_DECREF(ctx); Py_DECREF(py_statement); Py_XDECREF(ret); } @@ -1889,6 +1910,7 @@ collation_callback(void *context, int text1_length, const void *text1_data, { PyGILState_STATE gilstate = PyGILState_Ensure(); + pysqlite_CallbackContext *ctx = NULL; PyObject* string1 = 0; PyObject* string2 = 0; PyObject* retval = NULL; @@ -1910,8 +1932,11 @@ collation_callback(void *context, int text1_length, const void *text1_data, goto finally; } - pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(context); + ctx = pysqlite_CallbackContext_CAST(context); assert(ctx != NULL); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs. size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET; retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL); @@ -1931,8 +1956,10 @@ collation_callback(void *context, int text1_length, const void *text1_data, else if (longval < 0) result = -1; } + Py_CLEAR(ctx); finally: + Py_XDECREF(ctx); Py_XDECREF(string1); Py_XDECREF(string2); Py_XDECREF(retval); From f089fb6abdb293682270f3275b88d79c31c87b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Mon, 29 Dec 2025 19:04:39 +0100 Subject: [PATCH 3/4] fix more crashes and move test where they belong --- Lib/test/test_sqlite3/test_dbapi.py | 99 --------------------- Lib/test/test_sqlite3/test_hooks.py | 130 +++++++++++++++++++++++++++- Lib/test/test_sqlite3/util.py | 2 +- Modules/_sqlite/connection.c | 5 +- Modules/_sqlite/context.c | 4 +- Modules/_sqlite/context.h | 12 +-- 6 files changed, 139 insertions(+), 113 deletions(-) diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 5e07f42f898c79..20e39f61e4dedb 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -2029,104 +2029,5 @@ def test_row_is_a_sequence(self): self.assertIsInstance(row, Sequence) -class CallbackTests(unittest.TestCase): - - def setUp(self): - super().setUp() - self.cx = sqlite.connect(":memory:") - self.addCleanup(self.cx.close) - self.cu = self.cx.cursor() - self.cu.execute("create table test(a number)") - - class Handler: - cx = self.cx - - self.handler_class = Handler - - def assert_not_authorized(self, func, /, *args, **kwargs): - with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"): - func(*args, **kwargs) - - def assert_interrupted(self, func, /, *args, **kwargs): - with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"): - func(*args, **kwargs) - - def assert_invalid_trace(self, func, /, *args, **kwargs): - # Exception in trace callbacks are entirely suppressed. - pass - - # When a handler has an invalid signature, the exception raised is - # the same that would be raised if the handler "negatively" replied. - - def test_authorizer_invalid_signature(self): - self.cx.set_authorizer(lambda: None) - self.assert_not_authorized(self.cx.execute, "select * from test") - - def test_progress_handler_invalid_signature(self): - self.cx.set_progress_handler(lambda x: None, 1) - self.assert_interrupted(self.cx.execute, "select * from test") - - def test_trace_callback_invalid_signature_traceback(self): - self.cx.set_trace_callback(lambda: None) - self.assert_invalid_trace(self.cx.execute, "select * from test") - - # Tests for checking that callback context mutations do not crash. - # Regression tests for https://github.com/python/cpython/issues/142830. - - def test_authorizer_concurrent_mutation_in_call(self): - class Handler(self.handler_class): - def __call__(self, *a, **kw): - self.cx.set_authorizer(None) - raise ValueError - - self.cx.set_authorizer(Handler()) - self.assert_not_authorized(self.cx.execute, "select * from test") - - def test_authorizer_concurrent_mutation_with_overflown_value(self): - _testcapi = import_helper.import_module("_testcapi") - - class Handler(self.handler_class): - def __call__(self, *a, **kw): - self.cx.set_authorizer(None) - # We expect 'int' at the C level, so this one will raise - # when converting via PyLong_Int(). - return _testcapi.INT_MAX + 1 - - self.cx.set_authorizer(Handler()) - self.assert_not_authorized(self.cx.execute, "select * from test") - - def test_progress_handler_concurrent_mutation_in_call(self): - class Handler(self.handler_class): - def __call__(self, *a, **kw): - self.cx.set_authorizer(None) - raise ValueError - - self.cx.set_progress_handler(Handler(), 1) - self.assert_interrupted(self.cx.execute, "select * from test") - - def test_progress_handler_concurrent_mutation_in_conversion(self): - class Handler(self.handler_class): - def __bool__(self): - # clear the progress handler - self.cx.set_progress_handler(None, 1) - raise ValueError # force PyObject_True() to fail - - self.cx.set_progress_handler(Handler.__init__, 1) - self.assert_interrupted(self.cx.execute, "select * from test") - - def test_trace_callback_concurrent_mutation_in_call(self): - class Handler: - def __call__(self, statement): - # clear the progress handler - self.cx.set_progress_handler(None, 1) - raise ValueError - - self.cx.set_trace_callback(Handler()) - self.assert_invalid_trace(self.cx.execute, "select * from test") - - # TODO(picnixz): increase test coverage for other callbacks - # such as 'func', 'step', 'finalize', and 'collation'. - - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 2b907e35131d06..00d3bce5b7ee22 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -24,11 +24,15 @@ import sqlite3 as sqlite import unittest +from test.support import import_helper from test.support.os_helper import TESTFN, unlink from .util import memory_database, cx_limit, with_tracebacks from .util import MemoryDatabaseMixin +# TODO(picnixz): increase test coverage for other callbacks +# such as 'func', 'step', 'finalize', and 'collation'. + class CollationTests(MemoryDatabaseMixin, unittest.TestCase): @@ -129,8 +133,59 @@ def test_deregister_collation(self): self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') +class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase): + + def assert_not_authorized(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"): + func(*args, **kwargs) + + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_authorizer_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_authorizer(lambda: None) + self.assert_not_authorized(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_authorizer_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + class Handler: + cx = self.cx + def __call__(self, *a, **kw): + self.cx.set_authorizer(None) + raise ZeroDivisionError("hello world") + + self.cx.set_authorizer(Handler()) + self.assert_not_authorized(self.cx.execute, "select * from test") + + @with_tracebacks(OverflowError) + def test_authorizer_concurrent_mutation_with_overflown_value(self): + _testcapi = import_helper.import_module("_testcapi") + self.cx.execute("create table if not exists test(a number)") + + class Handler: + cx = self.cx + def __call__(self, *a, **kw): + self.cx.set_authorizer(None) + # We expect 'int' at the C level, so this one will raise + # when converting via PyLong_Int(). + return _testcapi.INT_MAX + 1 + + self.cx.set_authorizer(Handler()) + self.assert_not_authorized(self.cx.execute, "select * from test") + + class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): + def assert_interrupted(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"): + func(*args, **kwargs) + def test_progress_handler_used(self): """ Test that the progress handler is invoked once it is set. @@ -219,11 +274,51 @@ def bad_progress(): create table foo(a, b) """) - def test_progress_handler_keyword_args(self): + def test_set_progress_handler_keyword_args(self): with self.assertRaisesRegex(TypeError, 'takes at least 1 positional argument'): self.con.set_progress_handler(progress_handler=lambda: None, n=1) + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_progress_handler_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_progress_handler(lambda x: None, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_progress_handler_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + class Handler: + cx = self.cx + def __call__(self, *a, **kw): + self.cx.set_progress_handler(None, 1) + raise ZeroDivisionError("hello world") + + self.cx.set_progress_handler(Handler(), 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + def test_progress_handler_concurrent_mutation_in_conversion(self): + self.cx.execute("create table if not exists test(a number)") + + class Handler: + cx = self.cx + def __bool__(self): + # clear the progress handler + self.cx.set_progress_handler(None, 1) + raise ValueError # force PyObject_True() to fail + + self.cx.set_progress_handler(Handler.__init__, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + # Running with tracebacks makes the second execution of this + # function raise another exception because of a database change. + class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): @@ -345,11 +440,42 @@ def test_trace_bad_handler(self): cx.set_trace_callback(lambda stmt: 5/0) cx.execute("select 1") - def test_trace_keyword_args(self): + def test_set_trace_callback_keyword_args(self): with self.assertRaisesRegex(TypeError, 'takes exactly 1 positional argument'): self.con.set_trace_callback(trace_callback=lambda: None) + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied, + # but for the trace handler, exceptions are never re-raised (only + # printed when needed). + + @with_tracebacks( + TypeError, + regex=r".*\(\) missing 6 required positional arguments", + ) + def test_trace_handler_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None) + self.cx.execute("select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_trace_callback_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + class Handler: + cx = self.cx + def __call__(self, statement): + # clear the progress handler + self.cx.set_trace_callback(None) + raise ZeroDivisionError("hello world") + + self.cx.set_trace_callback(Handler()) + self.cx.execute("select * from test") + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/util.py b/Lib/test/test_sqlite3/util.py index cccd062160fd40..cd88d7579671d0 100644 --- a/Lib/test/test_sqlite3/util.py +++ b/Lib/test/test_sqlite3/util.py @@ -50,7 +50,7 @@ def check_tracebacks(self, cm, exc, exc_regex, msg_regex, obj_name): with contextlib.redirect_stderr(buf): yield - self.assertEqual(cm.unraisable.exc_type, exc) + self.assertIsSubclass(cm.unraisable.exc_type, exc) if exc_regex: msg = str(cm.unraisable.exc_value) self.assertIsNotNone(exc_regex.search(msg), (exc_regex, msg)) diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 26e2cddbef9d6b..d8527a2a79d32e 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1450,6 +1450,8 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) assert(ctx_vp != NULL); pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); pysqlite_state *state = ctx->state; assert(state != NULL); @@ -1474,9 +1476,7 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) sqlite3_free((void *)expanded_sql); } if (py_statement) { - Py_INCREF(ctx); PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement); - Py_DECREF(ctx); Py_DECREF(py_statement); Py_XDECREF(ret); } @@ -1485,6 +1485,7 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) } exit: + Py_DECREF(ctx); PyGILState_Release(gilstate); return 0; } diff --git a/Modules/_sqlite/context.c b/Modules/_sqlite/context.c index 01077eb33215aa..48f6886d11137f 100644 --- a/Modules/_sqlite/context.c +++ b/Modules/_sqlite/context.c @@ -49,7 +49,6 @@ callback_context_new_impl(PyTypeObject *type, PyObject *callable) ctx->state = state; PyObject_GC_Track(ctx); return (PyObject *)ctx; - } static int @@ -104,8 +103,7 @@ static PyType_Spec callback_context_spec = { PyObject * pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable) { - PyTypeObject *type = state->CallbackContextType; - return callback_context_new_impl(type, callable); + return callback_context_new_impl(state->CallbackContextType, callable); } int diff --git a/Modules/_sqlite/context.h b/Modules/_sqlite/context.h index 10f893ad50e382..9775b2cef46502 100644 --- a/Modules/_sqlite/context.h +++ b/Modules/_sqlite/context.h @@ -9,6 +9,12 @@ extern "C" { #include "module.h" +/* + * UDF/callback context structure. + * + * In order to ensure that the state pointer always outlives the callback + * context, we make sure it owns a reference to the module itself. + */ typedef struct { PyObject_HEAD PyObject *callable; @@ -18,12 +24,6 @@ typedef struct { #define pysqlite_CallbackContext_CAST(op) ((pysqlite_CallbackContext *)(op)) -/* Allocate a UDF/callback context structure. In order to ensure that the state - * pointer always outlives the callback context, we make sure it owns a - * reference to the module itself. create_callback_context() is always called - * from connection methods, so we use the defining class to fetch the module - * pointer. - */ PyObject * pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable); From 80058035011c3edaaa9a49cd81d453d57e1f326a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Mon, 29 Dec 2025 20:44:32 +0100 Subject: [PATCH 4/4] simplify tests --- Lib/test/test_sqlite3/test_hooks.py | 51 ++++++++++++----------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 00d3bce5b7ee22..495ef97fa3c61c 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -154,13 +154,11 @@ def test_authorizer_invalid_signature(self): def test_authorizer_concurrent_mutation_in_call(self): self.cx.execute("create table if not exists test(a number)") - class Handler: - cx = self.cx - def __call__(self, *a, **kw): - self.cx.set_authorizer(None) - raise ZeroDivisionError("hello world") + def handler(*a, **kw): + self.cx.set_authorizer(None) + raise ZeroDivisionError("hello world") - self.cx.set_authorizer(Handler()) + self.cx.set_authorizer(handler) self.assert_not_authorized(self.cx.execute, "select * from test") @with_tracebacks(OverflowError) @@ -168,15 +166,13 @@ def test_authorizer_concurrent_mutation_with_overflown_value(self): _testcapi = import_helper.import_module("_testcapi") self.cx.execute("create table if not exists test(a number)") - class Handler: - cx = self.cx - def __call__(self, *a, **kw): - self.cx.set_authorizer(None) - # We expect 'int' at the C level, so this one will raise - # when converting via PyLong_Int(). - return _testcapi.INT_MAX + 1 - - self.cx.set_authorizer(Handler()) + def handler(*a, **kw): + self.cx.set_authorizer(None) + # We expect 'int' at the C level, so this one will raise + # when converting via PyLong_Int(). + return _testcapi.INT_MAX + 1 + + self.cx.set_authorizer(handler) self.assert_not_authorized(self.cx.execute, "select * from test") @@ -294,21 +290,18 @@ def test_progress_handler_invalid_signature(self): def test_progress_handler_concurrent_mutation_in_call(self): self.cx.execute("create table if not exists test(a number)") - class Handler: - cx = self.cx - def __call__(self, *a, **kw): - self.cx.set_progress_handler(None, 1) - raise ZeroDivisionError("hello world") + def handler(*a, **kw): + self.cx.set_progress_handler(None, 1) + raise ZeroDivisionError("hello world") - self.cx.set_progress_handler(Handler(), 1) + self.cx.set_progress_handler(handler, 1) self.assert_interrupted(self.cx.execute, "select * from test") def test_progress_handler_concurrent_mutation_in_conversion(self): self.cx.execute("create table if not exists test(a number)") class Handler: - cx = self.cx - def __bool__(self): + def __bool__(_): # clear the progress handler self.cx.set_progress_handler(None, 1) raise ValueError # force PyObject_True() to fail @@ -466,14 +459,12 @@ def test_trace_handler_invalid_signature(self): def test_trace_callback_concurrent_mutation_in_call(self): self.cx.execute("create table if not exists test(a number)") - class Handler: - cx = self.cx - def __call__(self, statement): - # clear the progress handler - self.cx.set_trace_callback(None) - raise ZeroDivisionError("hello world") + def handler(statement): + # clear the progress handler + self.cx.set_trace_callback(None) + raise ZeroDivisionError("hello world") - self.cx.set_trace_callback(Handler()) + self.cx.set_trace_callback(handler) self.cx.execute("select * from test")