From c2f25b51798c15b59300c0881c06262cab0c452d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?=
<10796600+picnixz@users.noreply.github.com>
Date: Sun, 28 Dec 2025 12:49:29 +0100
Subject: [PATCH 1/4] gh-142830: convert callback context into a full-fledged
`PyObject *`
---
Makefile.pre.in | 2 +-
Modules/Setup.stdlib.in | 2 +-
Modules/_sqlite/clinic/context.c.h | 30 +++++
Modules/_sqlite/connection.c | 192 ++++++++++-------------------
Modules/_sqlite/connection.h | 15 +--
Modules/_sqlite/context.c | 121 ++++++++++++++++++
Modules/_sqlite/context.h | 37 ++++++
Modules/_sqlite/module.c | 1 +
Modules/_sqlite/module.h | 1 +
PCbuild/_sqlite3.vcxproj | 4 +-
PCbuild/_sqlite3.vcxproj.filters | 8 +-
11 files changed, 275 insertions(+), 138 deletions(-)
create mode 100644 Modules/_sqlite/clinic/context.c.h
create mode 100644 Modules/_sqlite/context.c
create mode 100644 Modules/_sqlite/context.h
diff --git a/Makefile.pre.in b/Makefile.pre.in
index a6beb96d12a3f2..36c81b0337fddd 100644
--- a/Makefile.pre.in
+++ b/Makefile.pre.in
@@ -3429,7 +3429,7 @@ MODULE__SSL_DEPS=$(srcdir)/Modules/_ssl.h $(srcdir)/Modules/_ssl/cert.c $(srcdir
MODULE__TESTCAPI_DEPS=$(srcdir)/Modules/_testcapi/parts.h $(srcdir)/Modules/_testcapi/util.h
MODULE__TESTLIMITEDCAPI_DEPS=$(srcdir)/Modules/_testlimitedcapi/testcapi_long.h $(srcdir)/Modules/_testlimitedcapi/parts.h $(srcdir)/Modules/_testlimitedcapi/util.h
MODULE__TESTINTERNALCAPI_DEPS=$(srcdir)/Modules/_testinternalcapi/parts.h
-MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h
+MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/context.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h
MODULE__ZSTD_DEPS=$(srcdir)/Modules/_zstd/_zstdmodule.h $(srcdir)/Modules/_zstd/buffer.h $(srcdir)/Modules/_zstd/zstddict.h
CODECS_COMMON_HEADERS=$(srcdir)/Modules/cjkcodecs/multibytecodec.h $(srcdir)/Modules/cjkcodecs/cjkcodecs.h
diff --git a/Modules/Setup.stdlib.in b/Modules/Setup.stdlib.in
index 2a4b937ce6bf80..59d649825fbe9d 100644
--- a/Modules/Setup.stdlib.in
+++ b/Modules/Setup.stdlib.in
@@ -149,7 +149,7 @@
# needs -lncurses[w] and -lpanel[w]
@MODULE__CURSES_PANEL_TRUE@_curses_panel _curses_panel.c
-@MODULE__SQLITE3_TRUE@_sqlite3 _sqlite/blob.c _sqlite/connection.c _sqlite/cursor.c _sqlite/microprotocols.c _sqlite/module.c _sqlite/prepare_protocol.c _sqlite/row.c _sqlite/statement.c _sqlite/util.c
+@MODULE__SQLITE3_TRUE@_sqlite3 _sqlite/blob.c _sqlite/connection.c _sqlite/context.c _sqlite/cursor.c _sqlite/microprotocols.c _sqlite/module.c _sqlite/prepare_protocol.c _sqlite/row.c _sqlite/statement.c _sqlite/util.c
# needs -lssl and -lcrypt
@MODULE__SSL_TRUE@_ssl _ssl.c
diff --git a/Modules/_sqlite/clinic/context.c.h b/Modules/_sqlite/clinic/context.c.h
new file mode 100644
index 00000000000000..9788f819ac7f68
--- /dev/null
+++ b/Modules/_sqlite/clinic/context.c.h
@@ -0,0 +1,30 @@
+/*[clinic input]
+preserve
+[clinic start generated code]*/
+
+#include "pycore_modsupport.h" // _PyArg_CheckPositional()
+
+static PyObject *
+callback_context_new_impl(PyTypeObject *type, PyObject *callable);
+
+static PyObject *
+callback_context_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
+{
+ PyObject *return_value = NULL;
+ PyTypeObject *base_tp = clinic_state()->CallbackContextType;
+ PyObject *callable;
+
+ if ((type == base_tp || type->tp_init == base_tp->tp_init) &&
+ !_PyArg_NoKeywords("_CallbackContext", kwargs)) {
+ goto exit;
+ }
+ if (!_PyArg_CheckPositional("_CallbackContext", PyTuple_GET_SIZE(args), 1, 1)) {
+ goto exit;
+ }
+ callable = PyTuple_GET_ITEM(args, 0);
+ return_value = callback_context_new_impl(type, callable);
+
+exit:
+ return return_value;
+}
+/*[clinic end generated code: output=370246c27daeaaa1 input=a9049054013a1b77]*/
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 83ff8e60557c07..53d1953b55e5b0 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -145,9 +145,6 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/
static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
-static void free_callback_context(callback_context *ctx);
-static void set_callback_context(callback_context **ctx_pp,
- callback_context *ctx);
static int connection_close(pysqlite_Connection *self);
PyObject *_pysqlite_query_execute(pysqlite_Cursor *, int, PyObject *, PyObject *);
@@ -377,13 +374,13 @@ output pop
[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=b899ba9273edcce7]*/
-#define VISIT_CALLBACK_CONTEXT(ctx) \
-do { \
- if (ctx) { \
- Py_VISIT(ctx->callable); \
- Py_VISIT(ctx->module); \
- } \
-} while (0)
+static void
+clear_callback_contexts(pysqlite_Connection *self)
+{
+ Py_CLEAR(self->trace_ctx);
+ Py_CLEAR(self->progress_ctx);
+ Py_CLEAR(self->authorizer_ctx);
+}
static int
connection_traverse(PyObject *op, visitproc visit, void *arg)
@@ -395,22 +392,12 @@ connection_traverse(PyObject *op, visitproc visit, void *arg)
Py_VISIT(self->blobs);
Py_VISIT(self->row_factory);
Py_VISIT(self->text_factory);
- VISIT_CALLBACK_CONTEXT(self->trace_ctx);
- VISIT_CALLBACK_CONTEXT(self->progress_ctx);
- VISIT_CALLBACK_CONTEXT(self->authorizer_ctx);
-#undef VISIT_CALLBACK_CONTEXT
+ Py_VISIT(self->trace_ctx);
+ Py_VISIT(self->progress_ctx);
+ Py_VISIT(self->authorizer_ctx);
return 0;
}
-static inline void
-clear_callback_context(callback_context *ctx)
-{
- if (ctx != NULL) {
- Py_CLEAR(ctx->callable);
- Py_CLEAR(ctx->module);
- }
-}
-
static int
connection_clear(PyObject *op)
{
@@ -420,20 +407,10 @@ connection_clear(PyObject *op)
Py_CLEAR(self->blobs);
Py_CLEAR(self->row_factory);
Py_CLEAR(self->text_factory);
- clear_callback_context(self->trace_ctx);
- clear_callback_context(self->progress_ctx);
- clear_callback_context(self->authorizer_ctx);
+ clear_callback_contexts(self);
return 0;
}
-static void
-free_callback_contexts(pysqlite_Connection *self)
-{
- set_callback_context(&self->trace_ctx, NULL);
- set_callback_context(&self->progress_ctx, NULL);
- set_callback_context(&self->authorizer_ctx, NULL);
-}
-
static void
remove_callbacks(sqlite3 *db)
{
@@ -474,7 +451,7 @@ connection_close(pysqlite_Connection *self)
(void)sqlite3_close_v2(db);
Py_END_ALLOW_THREADS
- free_callback_contexts(self);
+ clear_callback_contexts(self);
return rc;
}
@@ -814,7 +791,7 @@ _pysqlite_set_result(sqlite3_context* context, PyObject* py_val)
sqlite3_result_blob(context, view.buf, (int)view.len, SQLITE_TRANSIENT);
PyBuffer_Release(&view);
} else {
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
PyErr_Format(ctx->state->ProgrammingError,
"User-defined functions cannot return '%s' values to "
"SQLite",
@@ -893,7 +870,7 @@ _pysqlite_build_py_params(sqlite3_context *context, int argc,
}
static void
-print_or_clear_traceback(callback_context *ctx)
+print_or_clear_traceback(pysqlite_CallbackContext *ctx)
{
assert(ctx != NULL);
assert(ctx->state != NULL);
@@ -920,7 +897,7 @@ set_sqlite_error(sqlite3_context *context, const char *msg)
else {
sqlite3_result_error(context, msg, -1);
}
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
print_or_clear_traceback(ctx);
}
@@ -935,7 +912,7 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
args = _pysqlite_build_py_params(context, argc, argv);
if (args) {
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
py_retval = PyObject_CallObject(ctx->callable, args);
Py_DECREF(args);
@@ -963,7 +940,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
PyObject** aggregate_instance;
PyObject* stepmethod = NULL;
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
@@ -1032,7 +1009,7 @@ final_callback(sqlite3_context *context)
// Keep the exception (if any) of the last call to step, value, or inverse
PyObject *exc = PyErr_GetRaisedException();
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
ctx->state->str_finalize);
@@ -1095,55 +1072,14 @@ _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
return 0;
}
-/* Allocate a UDF/callback context structure. In order to ensure that the state
- * pointer always outlives the callback context, we make sure it owns a
- * reference to the module itself. create_callback_context() is always called
- * from connection methods, so we use the defining class to fetch the module
- * pointer.
- */
-static callback_context *
-create_callback_context(PyTypeObject *cls, PyObject *callable)
-{
- callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
- if (ctx != NULL) {
- PyObject *module = PyType_GetModule(cls);
- ctx->callable = Py_NewRef(callable);
- ctx->module = Py_NewRef(module);
- ctx->state = pysqlite_get_state(module);
- }
- return ctx;
-}
-
-static void
-free_callback_context(callback_context *ctx)
-{
- assert(ctx != NULL);
- Py_XDECREF(ctx->callable);
- Py_XDECREF(ctx->module);
- PyMem_Free(ctx);
-}
-
-static void
-set_callback_context(callback_context **ctx_pp, callback_context *ctx)
-{
- assert(ctx_pp != NULL);
- callback_context *tmp = *ctx_pp;
- *ctx_pp = ctx;
- if (tmp != NULL) {
- free_callback_context(tmp);
- }
-}
-
static void
destructor_callback(void *ctx)
{
- if (ctx != NULL) {
- // This function may be called without the GIL held, so we need to
- // ensure that we destroy 'ctx' with the GIL held.
- PyGILState_STATE gstate = PyGILState_Ensure();
- free_callback_context((callback_context *)ctx);
- PyGILState_Release(gstate);
- }
+ // This function may be called without the GIL held, so we need to
+ // ensure that we destroy 'ctx' with the GIL held.
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ Py_DECREF(ctx);
+ PyGILState_Release(gstate);
}
static int
@@ -1194,7 +1130,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
if (deterministic) {
flags |= SQLITE_DETERMINISTIC;
}
- callback_context *ctx = create_callback_context(cls, func);
+ PyObject *ctx = pysqlite_create_callback_context(self->state, func);
if (ctx == NULL) {
return NULL;
}
@@ -1225,7 +1161,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
{
PyGILState_STATE gilstate = PyGILState_Ensure();
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
int size = sizeof(PyObject *);
@@ -1273,7 +1209,7 @@ value_callback(sqlite3_context *context)
{
PyGILState_STATE gilstate = PyGILState_Ensure();
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
int size = sizeof(PyObject *);
@@ -1344,7 +1280,7 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
0, 0, 0, 0, 0, 0);
}
else {
- callback_context *ctx = create_callback_context(cls, aggregate_class);
+ PyObject *ctx = pysqlite_create_callback_context(self->state, aggregate_class);
if (ctx == NULL) {
return NULL;
}
@@ -1395,7 +1331,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
return NULL;
}
- callback_context *ctx = create_callback_context(cls, aggregate_class);
+ PyObject *ctx = pysqlite_create_callback_context(self->state, aggregate_class);
if (ctx == NULL) {
return NULL;
}
@@ -1413,7 +1349,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
}
static int
-authorizer_callback(void *ctx, int action, const char *arg1,
+authorizer_callback(void *ctx_vp, int action, const char *arg1,
const char *arg2 , const char *dbname,
const char *access_attempt_source)
{
@@ -1422,8 +1358,9 @@ authorizer_callback(void *ctx, int action, const char *arg1,
PyObject *ret;
int rc = SQLITE_DENY;
- assert(ctx != NULL);
- PyObject *callable = ((callback_context *)ctx)->callable;
+ assert(ctx_vp != NULL);
+ pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
+ PyObject *callable = ctx->callable;
ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
access_attempt_source);
@@ -1450,15 +1387,16 @@ authorizer_callback(void *ctx, int action, const char *arg1,
}
static int
-progress_callback(void *ctx)
+progress_callback(void *ctx_vp)
{
PyGILState_STATE gilstate = PyGILState_Ensure();
int rc;
PyObject *ret;
- assert(ctx != NULL);
- PyObject *callable = ((callback_context *)ctx)->callable;
+ assert(ctx_vp != NULL);
+ pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
+ PyObject *callable = ctx->callable;
ret = PyObject_CallNoArgs(callable);
if (!ret) {
/* abort query if error occurred */
@@ -1483,7 +1421,7 @@ progress_callback(void *ctx)
* to ensure future compatibility.
*/
static int
-trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
+trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
{
if (type != SQLITE_TRACE_STMT) {
return 0;
@@ -1491,8 +1429,9 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
PyGILState_STATE gilstate = PyGILState_Ensure();
- assert(ctx != NULL);
- pysqlite_state *state = ((callback_context *)ctx)->state;
+ assert(ctx_vp != NULL);
+ pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
+ pysqlite_state *state = ctx->state;
assert(state != NULL);
PyObject *py_statement = NULL;
@@ -1506,7 +1445,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
PyErr_SetString(state->DataError,
"Expanded SQL string exceeds the maximum string length");
- print_or_clear_traceback((callback_context *)ctx);
+ print_or_clear_traceback(ctx);
// Fall back to unexpanded sql
py_statement = PyUnicode_FromString((const char *)sql);
@@ -1516,13 +1455,12 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
sqlite3_free((void *)expanded_sql);
}
if (py_statement) {
- PyObject *callable = ((callback_context *)ctx)->callable;
- PyObject *ret = PyObject_CallOneArg(callable, py_statement);
+ PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}
if (PyErr_Occurred()) {
- print_or_clear_traceback((callback_context *)ctx);
+ print_or_clear_traceback(ctx);
}
exit:
@@ -1552,21 +1490,22 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self,
int rc;
if (callable == Py_None) {
+ /* None clears the authorizer callback previously set */
rc = sqlite3_set_authorizer(self->db, NULL, NULL);
- set_callback_context(&self->authorizer_ctx, NULL);
+ Py_CLEAR(self->authorizer_ctx);
}
else {
- callback_context *ctx = create_callback_context(cls, callable);
+ PyObject *ctx = pysqlite_create_callback_context(self->state, callable);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_set_authorizer(self->db, authorizer_callback, ctx);
- set_callback_context(&self->authorizer_ctx, ctx);
+ Py_XSETREF(self->authorizer_ctx, ctx);
}
if (rc != SQLITE_OK) {
PyErr_SetString(self->OperationalError,
"Error setting authorizer callback");
- set_callback_context(&self->authorizer_ctx, NULL);
+ Py_CLEAR(self->authorizer_ctx);
return NULL;
}
Py_RETURN_NONE;
@@ -1604,15 +1543,15 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self,
if (callable == Py_None) {
/* None clears the progress handler previously set */
sqlite3_progress_handler(self->db, 0, 0, (void*)0);
- set_callback_context(&self->progress_ctx, NULL);
+ Py_CLEAR(self->progress_ctx);
}
else {
- callback_context *ctx = create_callback_context(cls, callable);
+ PyObject *ctx = pysqlite_create_callback_context(self->state, callable);
if (ctx == NULL) {
return NULL;
}
sqlite3_progress_handler(self->db, n, progress_callback, ctx);
- set_callback_context(&self->progress_ctx, ctx);
+ Py_XSETREF(self->progress_ctx, ctx);
}
Py_RETURN_NONE;
}
@@ -1637,6 +1576,7 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self,
return NULL;
}
+ int rc;
if (callable == Py_None) {
/*
* None clears the trace callback previously set
@@ -1645,18 +1585,22 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self,
* - https://sqlite.org/c3ref/c_trace.html
* - https://sqlite.org/c3ref/trace_v2.html
*/
- sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, 0, 0);
- set_callback_context(&self->trace_ctx, NULL);
+ rc = sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, 0, 0);
+ Py_CLEAR(self->trace_ctx);
}
else {
- callback_context *ctx = create_callback_context(cls, callable);
+ PyObject *ctx = pysqlite_create_callback_context(self->state, callable);
if (ctx == NULL) {
return NULL;
}
- sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx);
- set_callback_context(&self->trace_ctx, ctx);
+ rc = sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx);
+ Py_XSETREF(self->trace_ctx, ctx);
+ }
+ if (rc != SQLITE_OK) {
+ PyErr_SetString(self->OperationalError, "Error setting trace callback");
+ Py_CLEAR(self->trace_ctx);
+ return NULL;
}
-
Py_RETURN_NONE;
}
@@ -1966,7 +1910,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
goto finally;
}
- callback_context *ctx = (callback_context *)context;
+ pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(context);
assert(ctx != NULL);
PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs.
size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
@@ -2185,7 +2129,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
return NULL;
}
- callback_context *ctx = NULL;
+ PyObject *ctx = NULL;
int rc;
int flags = SQLITE_UTF8;
if (callable == Py_None) {
@@ -2193,11 +2137,13 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
NULL, NULL, NULL);
}
else {
+ // TODO(picnixz): defer this check to the context's constructor
+ // and do it for all other functions that create a context.
if (!PyCallable_Check(callable)) {
PyErr_SetString(PyExc_TypeError, "parameter must be callable");
return NULL;
}
- ctx = create_callback_context(cls, callable);
+ ctx = pysqlite_create_callback_context(self->state, callable);
if (ctx == NULL) {
return NULL;
}
@@ -2211,9 +2157,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
* called if sqlite3_create_collation_v2() fails, so we have to free
* the context before returning.
*/
- if (callable != Py_None) {
- free_callback_context(ctx);
- }
+ Py_XDECREF(ctx);
set_error_from_db(self->state, self->db);
return NULL;
}
diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h
index 7a748ee3ea0c58..20c20d772befa3 100644
--- a/Modules/_sqlite/connection.h
+++ b/Modules/_sqlite/connection.h
@@ -23,21 +23,16 @@
#ifndef PYSQLITE_CONNECTION_H
#define PYSQLITE_CONNECTION_H
+
#include "Python.h"
#include "pythread.h"
#include "structmember.h"
+#include "context.h"
#include "module.h"
#include "sqlite3.h"
-typedef struct _callback_context
-{
- PyObject *callable;
- PyObject *module;
- pysqlite_state *state;
-} callback_context;
-
enum autocommit_mode {
AUTOCOMMIT_LEGACY = LEGACY_TRANSACTION_CONTROL,
AUTOCOMMIT_ENABLED = 1,
@@ -88,9 +83,9 @@ typedef struct
PyObject* text_factory;
// Remember contexts used by the trace, progress, and authoriser callbacks
- callback_context *trace_ctx;
- callback_context *progress_ctx;
- callback_context *authorizer_ctx;
+ PyObject *trace_ctx;
+ PyObject *progress_ctx;
+ PyObject *authorizer_ctx;
/* Exception objects: borrowed refs. */
PyObject* Warning;
diff --git a/Modules/_sqlite/context.c b/Modules/_sqlite/context.c
new file mode 100644
index 00000000000000..01077eb33215aa
--- /dev/null
+++ b/Modules/_sqlite/context.c
@@ -0,0 +1,121 @@
+#ifndef Py_BUILD_CORE_BUILTIN
+# define Py_BUILD_CORE_MODULE 1
+#endif
+
+#include "context.h"
+#include "module.h"
+
+#define clinic_state() (pysqlite_get_state_by_type(type))
+#include "clinic/context.c.h"
+#undef clinic_state
+
+/*[clinic input]
+module _sqlite3
+class _sqlite3._CallbackContext "pysqlite_CallbackContext *" "clinic_state()->CallbackContextType"
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=f06c18b7c3666e05]*/
+
+/*[clinic input]
+@classmethod
+_sqlite3._CallbackContext.__new__ as callback_context_new
+
+ callable: object
+ /
+
+[clinic start generated code]*/
+
+static PyObject *
+callback_context_new_impl(PyTypeObject *type, PyObject *callable)
+/*[clinic end generated code: output=46b4f355475d88cc input=ae33656b48f65a6d]*/
+{
+ PyObject *module = PyType_GetModuleByDef(type, &_sqlite3module);
+ assert(module != NULL);
+ pysqlite_state *state = pysqlite_get_state_by_type(type);
+ assert(state != NULL);
+
+ pysqlite_CallbackContext *ctx = PyObject_GC_New(
+ pysqlite_CallbackContext,
+ state->CallbackContextType
+ );
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ // TODO(picnixz): check that 'callable' is effectively callable
+ // instead of relying on a generic TypeError when attempting to
+ // call it.
+ ctx->callable = Py_NewRef(callable);
+ ctx->module = Py_NewRef(module);
+ ctx->state = state;
+ PyObject_GC_Track(ctx);
+ return (PyObject *)ctx;
+
+}
+
+static int
+callback_context_clear(PyObject *op)
+{
+ pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(op);
+ Py_CLEAR(ctx->callable);
+ Py_CLEAR(ctx->module);
+ return 0;
+}
+
+static void
+callback_context_dealloc(PyObject *op)
+{
+ PyTypeObject *type = Py_TYPE(op);
+ PyObject_GC_UnTrack(op);
+ (void)type->tp_clear(op);
+ type->tp_free(op);
+ Py_DECREF(type);
+}
+
+static int
+callback_context_traverse(PyObject *op, visitproc visit, void *arg)
+{
+ pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(op);
+ Py_VISIT(Py_TYPE(op));
+ Py_VISIT(ctx->callable);
+ Py_VISIT(ctx->module);
+ return 0;
+}
+
+static PyType_Slot callback_context_slots[] = {
+ {Py_tp_new, callback_context_new},
+ {Py_tp_clear, callback_context_clear},
+ {Py_tp_dealloc, callback_context_dealloc},
+ {Py_tp_traverse, callback_context_traverse},
+ {0, NULL},
+};
+
+static PyType_Spec callback_context_spec = {
+ .name = MODULE_NAME "._CallbackContext",
+ .basicsize = sizeof(pysqlite_CallbackContext),
+ .flags = (
+ Py_TPFLAGS_DEFAULT
+ | Py_TPFLAGS_DISALLOW_INSTANTIATION
+ | Py_TPFLAGS_IMMUTABLETYPE
+ | Py_TPFLAGS_HAVE_GC
+ ),
+ .slots = callback_context_slots,
+};
+
+PyObject *
+pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable)
+{
+ PyTypeObject *type = state->CallbackContextType;
+ return callback_context_new_impl(type, callable);
+}
+
+int
+pysqlite_context_setup_types(PyObject *module)
+{
+ PyObject *type = PyType_FromModuleAndSpec(module, &callback_context_spec, NULL);
+ if (type == NULL) {
+ return -1;
+ }
+ pysqlite_state *state = pysqlite_get_state(module);
+ state->CallbackContextType = (PyTypeObject *)type;
+ return 0;
+}
diff --git a/Modules/_sqlite/context.h b/Modules/_sqlite/context.h
new file mode 100644
index 00000000000000..10f893ad50e382
--- /dev/null
+++ b/Modules/_sqlite/context.h
@@ -0,0 +1,37 @@
+#ifndef PYSQLITE_CALLBACK_CONTEXT_H
+#define PYSQLITE_CALLBACK_CONTEXT_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "Python.h"
+
+#include "module.h"
+
+typedef struct {
+ PyObject_HEAD
+ PyObject *callable;
+ PyObject *module;
+ pysqlite_state *state;
+} pysqlite_CallbackContext;
+
+#define pysqlite_CallbackContext_CAST(op) ((pysqlite_CallbackContext *)(op))
+
+/* Allocate a UDF/callback context structure. In order to ensure that the state
+ * pointer always outlives the callback context, we make sure it owns a
+ * reference to the module itself. create_callback_context() is always called
+ * from connection methods, so we use the defining class to fetch the module
+ * pointer.
+ */
+PyObject *
+pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable);
+
+int
+pysqlite_context_setup_types(PyObject *module);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // !PYSQLITE_CALLBACK_CONTEXT_H
diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c
index 831dd9219f77ab..923268159a4197 100644
--- a/Modules/_sqlite/module.c
+++ b/Modules/_sqlite/module.c
@@ -686,6 +686,7 @@ module_exec(PyObject *module)
if ((pysqlite_row_setup_types(module) < 0) ||
(pysqlite_cursor_setup_types(module) < 0) ||
+ (pysqlite_context_setup_types(module) < 0) ||
(pysqlite_connection_setup_types(module) < 0) ||
(pysqlite_statement_setup_types(module) < 0) ||
(pysqlite_prepare_protocol_setup_types(module) < 0) ||
diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h
index a4ca45cf6326a9..4fee440146c7f6 100644
--- a/Modules/_sqlite/module.h
+++ b/Modules/_sqlite/module.h
@@ -55,6 +55,7 @@ typedef struct {
int enable_callback_tracebacks;
PyTypeObject *BlobType;
+ PyTypeObject *CallbackContextType;
PyTypeObject *ConnectionType;
PyTypeObject *CursorType;
PyTypeObject *PrepareProtocolType;
diff --git a/PCbuild/_sqlite3.vcxproj b/PCbuild/_sqlite3.vcxproj
index 9ae0a0fc3a009d..4919b28293875e 100644
--- a/PCbuild/_sqlite3.vcxproj
+++ b/PCbuild/_sqlite3.vcxproj
@@ -99,6 +99,7 @@
+
@@ -110,6 +111,7 @@
+
@@ -135,4 +137,4 @@
-
\ No newline at end of file
+
diff --git a/PCbuild/_sqlite3.vcxproj.filters b/PCbuild/_sqlite3.vcxproj.filters
index f4a265eba7dd80..6a930a01d2e519 100644
--- a/PCbuild/_sqlite3.vcxproj.filters
+++ b/PCbuild/_sqlite3.vcxproj.filters
@@ -15,6 +15,9 @@
Header Files
+
+ Header Files
+
Header Files
@@ -44,6 +47,9 @@
Source Files
+
+ Source Files
+
Source Files
@@ -74,4 +80,4 @@
Resource Files
-
\ No newline at end of file
+
From 4dd06525f881d8a19a5b1b6a1334db1d60d2dd32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?=
<10796600+picnixz@users.noreply.github.com>
Date: Sun, 28 Dec 2025 14:06:27 +0100
Subject: [PATCH 2/4] gh-142830: prevent crashes when replacing sqlite3
callbacks
---
Lib/test/test_sqlite3/test_dbapi.py | 99 +++++++++++++++++++
...-12-28-13-12-40.gh-issue-142830.uEyd6r.rst | 2 +
Modules/_sqlite/connection.c | 41 ++++++--
3 files changed, 135 insertions(+), 7 deletions(-)
create mode 100644 Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index 20e39f61e4dedb..5e07f42f898c79 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -2029,5 +2029,104 @@ def test_row_is_a_sequence(self):
self.assertIsInstance(row, Sequence)
+class CallbackTests(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.cx = sqlite.connect(":memory:")
+ self.addCleanup(self.cx.close)
+ self.cu = self.cx.cursor()
+ self.cu.execute("create table test(a number)")
+
+ class Handler:
+ cx = self.cx
+
+ self.handler_class = Handler
+
+ def assert_not_authorized(self, func, /, *args, **kwargs):
+ with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
+ func(*args, **kwargs)
+
+ def assert_interrupted(self, func, /, *args, **kwargs):
+ with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
+ func(*args, **kwargs)
+
+ def assert_invalid_trace(self, func, /, *args, **kwargs):
+ # Exception in trace callbacks are entirely suppressed.
+ pass
+
+ # When a handler has an invalid signature, the exception raised is
+ # the same that would be raised if the handler "negatively" replied.
+
+ def test_authorizer_invalid_signature(self):
+ self.cx.set_authorizer(lambda: None)
+ self.assert_not_authorized(self.cx.execute, "select * from test")
+
+ def test_progress_handler_invalid_signature(self):
+ self.cx.set_progress_handler(lambda x: None, 1)
+ self.assert_interrupted(self.cx.execute, "select * from test")
+
+ def test_trace_callback_invalid_signature_traceback(self):
+ self.cx.set_trace_callback(lambda: None)
+ self.assert_invalid_trace(self.cx.execute, "select * from test")
+
+ # Tests for checking that callback context mutations do not crash.
+ # Regression tests for https://github.com/python/cpython/issues/142830.
+
+ def test_authorizer_concurrent_mutation_in_call(self):
+ class Handler(self.handler_class):
+ def __call__(self, *a, **kw):
+ self.cx.set_authorizer(None)
+ raise ValueError
+
+ self.cx.set_authorizer(Handler())
+ self.assert_not_authorized(self.cx.execute, "select * from test")
+
+ def test_authorizer_concurrent_mutation_with_overflown_value(self):
+ _testcapi = import_helper.import_module("_testcapi")
+
+ class Handler(self.handler_class):
+ def __call__(self, *a, **kw):
+ self.cx.set_authorizer(None)
+ # We expect 'int' at the C level, so this one will raise
+ # when converting via PyLong_Int().
+ return _testcapi.INT_MAX + 1
+
+ self.cx.set_authorizer(Handler())
+ self.assert_not_authorized(self.cx.execute, "select * from test")
+
+ def test_progress_handler_concurrent_mutation_in_call(self):
+ class Handler(self.handler_class):
+ def __call__(self, *a, **kw):
+ self.cx.set_authorizer(None)
+ raise ValueError
+
+ self.cx.set_progress_handler(Handler(), 1)
+ self.assert_interrupted(self.cx.execute, "select * from test")
+
+ def test_progress_handler_concurrent_mutation_in_conversion(self):
+ class Handler(self.handler_class):
+ def __bool__(self):
+ # clear the progress handler
+ self.cx.set_progress_handler(None, 1)
+ raise ValueError # force PyObject_True() to fail
+
+ self.cx.set_progress_handler(Handler.__init__, 1)
+ self.assert_interrupted(self.cx.execute, "select * from test")
+
+ def test_trace_callback_concurrent_mutation_in_call(self):
+ class Handler:
+ def __call__(self, statement):
+ # clear the progress handler
+ self.cx.set_progress_handler(None, 1)
+ raise ValueError
+
+ self.cx.set_trace_callback(Handler())
+ self.assert_invalid_trace(self.cx.execute, "select * from test")
+
+ # TODO(picnixz): increase test coverage for other callbacks
+ # such as 'func', 'step', 'finalize', and 'collation'.
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst
new file mode 100644
index 00000000000000..246979e91d76b5
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst
@@ -0,0 +1,2 @@
+:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks
+are mutated during a callback execution. Patch by Bénédikt Tran.
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 53d1953b55e5b0..26e2cddbef9d6b 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -914,7 +914,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
if (args) {
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
+ Py_INCREF(ctx);
py_retval = PyObject_CallObject(ctx->callable, args);
+ Py_DECREF(ctx);
Py_DECREF(args);
}
@@ -942,6 +944,8 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
+ // Hold a reference to 'ctx' to prevent concurrent mutations.
+ Py_INCREF(ctx);
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (aggregate_instance == NULL) {
@@ -971,6 +975,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
}
function_result = PyObject_CallObject(stepmethod, args);
+ Py_CLEAR(ctx);
Py_DECREF(args);
if (!function_result) {
@@ -979,6 +984,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
}
error:
+ Py_XDECREF(ctx);
Py_XDECREF(stepmethod);
Py_XDECREF(function_result);
@@ -1011,8 +1017,10 @@ final_callback(sqlite3_context *context)
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
+ Py_INCREF(ctx);
function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
ctx->state->str_finalize);
+ Py_DECREF(ctx);
Py_DECREF(*aggregate_instance);
ok = 0;
@@ -1163,6 +1171,8 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
+ // Hold a reference to 'ctx' to prevent concurrent mutations.
+ Py_INCREF(ctx);
int size = sizeof(PyObject *);
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
@@ -1191,9 +1201,11 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
"user-defined aggregate's 'inverse' method raised error");
goto exit;
}
+ Py_CLEAR(ctx);
Py_DECREF(res);
exit:
+ Py_XDECREF(ctx);
Py_XDECREF(method);
PyGILState_Release(gilstate);
}
@@ -1217,7 +1229,10 @@ value_callback(sqlite3_context *context)
assert(cls != NULL);
assert(*cls != NULL);
+ Py_INCREF(ctx);
PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
+ Py_DECREF(ctx);
+
if (res == NULL) {
int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
set_sqlite_error(context, attr_err
@@ -1360,10 +1375,11 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1,
assert(ctx_vp != NULL);
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
- PyObject *callable = ctx->callable;
- ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
- access_attempt_source);
+ // Hold a reference to 'ctx' to prevent concurrent mutations.
+ Py_INCREF(ctx);
+ ret = PyObject_CallFunction(ctx->callable, "issss", action, arg1, arg2,
+ dbname, access_attempt_source);
if (ret == NULL) {
print_or_clear_traceback(ctx);
rc = SQLITE_DENY;
@@ -1381,6 +1397,7 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1,
}
Py_DECREF(ret);
}
+ Py_DECREF(ctx);
PyGILState_Release(gilstate);
return rc;
@@ -1396,8 +1413,10 @@ progress_callback(void *ctx_vp)
assert(ctx_vp != NULL);
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
- PyObject *callable = ctx->callable;
- ret = PyObject_CallNoArgs(callable);
+ // Hold a reference to 'ctx' to prevent concurrent mutations.
+ Py_INCREF(ctx);
+
+ ret = PyObject_CallNoArgs(ctx->callable);
if (!ret) {
/* abort query if error occurred */
rc = -1;
@@ -1409,7 +1428,7 @@ progress_callback(void *ctx_vp)
if (rc < 0) {
print_or_clear_traceback(ctx);
}
-
+ Py_DECREF(ctx);
PyGILState_Release(gilstate);
return rc;
}
@@ -1455,7 +1474,9 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
sqlite3_free((void *)expanded_sql);
}
if (py_statement) {
+ Py_INCREF(ctx);
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
+ Py_DECREF(ctx);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}
@@ -1889,6 +1910,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
{
PyGILState_STATE gilstate = PyGILState_Ensure();
+ pysqlite_CallbackContext *ctx = NULL;
PyObject* string1 = 0;
PyObject* string2 = 0;
PyObject* retval = NULL;
@@ -1910,8 +1932,11 @@ collation_callback(void *context, int text1_length, const void *text1_data,
goto finally;
}
- pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(context);
+ ctx = pysqlite_CallbackContext_CAST(context);
assert(ctx != NULL);
+ // Hold a reference to 'ctx' to prevent concurrent mutations.
+ Py_INCREF(ctx);
+
PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs.
size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL);
@@ -1931,8 +1956,10 @@ collation_callback(void *context, int text1_length, const void *text1_data,
else if (longval < 0)
result = -1;
}
+ Py_CLEAR(ctx);
finally:
+ Py_XDECREF(ctx);
Py_XDECREF(string1);
Py_XDECREF(string2);
Py_XDECREF(retval);
From f089fb6abdb293682270f3275b88d79c31c87b6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?=
<10796600+picnixz@users.noreply.github.com>
Date: Mon, 29 Dec 2025 19:04:39 +0100
Subject: [PATCH 3/4] fix more crashes and move test where they belong
---
Lib/test/test_sqlite3/test_dbapi.py | 99 ---------------------
Lib/test/test_sqlite3/test_hooks.py | 130 +++++++++++++++++++++++++++-
Lib/test/test_sqlite3/util.py | 2 +-
Modules/_sqlite/connection.c | 5 +-
Modules/_sqlite/context.c | 4 +-
Modules/_sqlite/context.h | 12 +--
6 files changed, 139 insertions(+), 113 deletions(-)
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index 5e07f42f898c79..20e39f61e4dedb 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -2029,104 +2029,5 @@ def test_row_is_a_sequence(self):
self.assertIsInstance(row, Sequence)
-class CallbackTests(unittest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.cx = sqlite.connect(":memory:")
- self.addCleanup(self.cx.close)
- self.cu = self.cx.cursor()
- self.cu.execute("create table test(a number)")
-
- class Handler:
- cx = self.cx
-
- self.handler_class = Handler
-
- def assert_not_authorized(self, func, /, *args, **kwargs):
- with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
- func(*args, **kwargs)
-
- def assert_interrupted(self, func, /, *args, **kwargs):
- with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
- func(*args, **kwargs)
-
- def assert_invalid_trace(self, func, /, *args, **kwargs):
- # Exception in trace callbacks are entirely suppressed.
- pass
-
- # When a handler has an invalid signature, the exception raised is
- # the same that would be raised if the handler "negatively" replied.
-
- def test_authorizer_invalid_signature(self):
- self.cx.set_authorizer(lambda: None)
- self.assert_not_authorized(self.cx.execute, "select * from test")
-
- def test_progress_handler_invalid_signature(self):
- self.cx.set_progress_handler(lambda x: None, 1)
- self.assert_interrupted(self.cx.execute, "select * from test")
-
- def test_trace_callback_invalid_signature_traceback(self):
- self.cx.set_trace_callback(lambda: None)
- self.assert_invalid_trace(self.cx.execute, "select * from test")
-
- # Tests for checking that callback context mutations do not crash.
- # Regression tests for https://github.com/python/cpython/issues/142830.
-
- def test_authorizer_concurrent_mutation_in_call(self):
- class Handler(self.handler_class):
- def __call__(self, *a, **kw):
- self.cx.set_authorizer(None)
- raise ValueError
-
- self.cx.set_authorizer(Handler())
- self.assert_not_authorized(self.cx.execute, "select * from test")
-
- def test_authorizer_concurrent_mutation_with_overflown_value(self):
- _testcapi = import_helper.import_module("_testcapi")
-
- class Handler(self.handler_class):
- def __call__(self, *a, **kw):
- self.cx.set_authorizer(None)
- # We expect 'int' at the C level, so this one will raise
- # when converting via PyLong_Int().
- return _testcapi.INT_MAX + 1
-
- self.cx.set_authorizer(Handler())
- self.assert_not_authorized(self.cx.execute, "select * from test")
-
- def test_progress_handler_concurrent_mutation_in_call(self):
- class Handler(self.handler_class):
- def __call__(self, *a, **kw):
- self.cx.set_authorizer(None)
- raise ValueError
-
- self.cx.set_progress_handler(Handler(), 1)
- self.assert_interrupted(self.cx.execute, "select * from test")
-
- def test_progress_handler_concurrent_mutation_in_conversion(self):
- class Handler(self.handler_class):
- def __bool__(self):
- # clear the progress handler
- self.cx.set_progress_handler(None, 1)
- raise ValueError # force PyObject_True() to fail
-
- self.cx.set_progress_handler(Handler.__init__, 1)
- self.assert_interrupted(self.cx.execute, "select * from test")
-
- def test_trace_callback_concurrent_mutation_in_call(self):
- class Handler:
- def __call__(self, statement):
- # clear the progress handler
- self.cx.set_progress_handler(None, 1)
- raise ValueError
-
- self.cx.set_trace_callback(Handler())
- self.assert_invalid_trace(self.cx.execute, "select * from test")
-
- # TODO(picnixz): increase test coverage for other callbacks
- # such as 'func', 'step', 'finalize', and 'collation'.
-
-
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py
index 2b907e35131d06..00d3bce5b7ee22 100644
--- a/Lib/test/test_sqlite3/test_hooks.py
+++ b/Lib/test/test_sqlite3/test_hooks.py
@@ -24,11 +24,15 @@
import sqlite3 as sqlite
import unittest
+from test.support import import_helper
from test.support.os_helper import TESTFN, unlink
from .util import memory_database, cx_limit, with_tracebacks
from .util import MemoryDatabaseMixin
+# TODO(picnixz): increase test coverage for other callbacks
+# such as 'func', 'step', 'finalize', and 'collation'.
+
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
@@ -129,8 +133,59 @@ def test_deregister_collation(self):
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
+class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase):
+
+ def assert_not_authorized(self, func, /, *args, **kwargs):
+ with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
+ func(*args, **kwargs)
+
+ # When a handler has an invalid signature, the exception raised is
+ # the same that would be raised if the handler "negatively" replied.
+
+ def test_authorizer_invalid_signature(self):
+ self.cx.execute("create table if not exists test(a number)")
+ self.cx.set_authorizer(lambda: None)
+ self.assert_not_authorized(self.cx.execute, "select * from test")
+
+ # Tests for checking that callback context mutations do not crash.
+ # Regression tests for https://github.com/python/cpython/issues/142830.
+
+ @with_tracebacks(ZeroDivisionError, regex="hello world")
+ def test_authorizer_concurrent_mutation_in_call(self):
+ self.cx.execute("create table if not exists test(a number)")
+
+ class Handler:
+ cx = self.cx
+ def __call__(self, *a, **kw):
+ self.cx.set_authorizer(None)
+ raise ZeroDivisionError("hello world")
+
+ self.cx.set_authorizer(Handler())
+ self.assert_not_authorized(self.cx.execute, "select * from test")
+
+ @with_tracebacks(OverflowError)
+ def test_authorizer_concurrent_mutation_with_overflown_value(self):
+ _testcapi = import_helper.import_module("_testcapi")
+ self.cx.execute("create table if not exists test(a number)")
+
+ class Handler:
+ cx = self.cx
+ def __call__(self, *a, **kw):
+ self.cx.set_authorizer(None)
+ # We expect 'int' at the C level, so this one will raise
+ # when converting via PyLong_Int().
+ return _testcapi.INT_MAX + 1
+
+ self.cx.set_authorizer(Handler())
+ self.assert_not_authorized(self.cx.execute, "select * from test")
+
+
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
+ def assert_interrupted(self, func, /, *args, **kwargs):
+ with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
+ func(*args, **kwargs)
+
def test_progress_handler_used(self):
"""
Test that the progress handler is invoked once it is set.
@@ -219,11 +274,51 @@ def bad_progress():
create table foo(a, b)
""")
- def test_progress_handler_keyword_args(self):
+ def test_set_progress_handler_keyword_args(self):
with self.assertRaisesRegex(TypeError,
'takes at least 1 positional argument'):
self.con.set_progress_handler(progress_handler=lambda: None, n=1)
+ # When a handler has an invalid signature, the exception raised is
+ # the same that would be raised if the handler "negatively" replied.
+
+ def test_progress_handler_invalid_signature(self):
+ self.cx.execute("create table if not exists test(a number)")
+ self.cx.set_progress_handler(lambda x: None, 1)
+ self.assert_interrupted(self.cx.execute, "select * from test")
+
+ # Tests for checking that callback context mutations do not crash.
+ # Regression tests for https://github.com/python/cpython/issues/142830.
+
+ @with_tracebacks(ZeroDivisionError, regex="hello world")
+ def test_progress_handler_concurrent_mutation_in_call(self):
+ self.cx.execute("create table if not exists test(a number)")
+
+ class Handler:
+ cx = self.cx
+ def __call__(self, *a, **kw):
+ self.cx.set_progress_handler(None, 1)
+ raise ZeroDivisionError("hello world")
+
+ self.cx.set_progress_handler(Handler(), 1)
+ self.assert_interrupted(self.cx.execute, "select * from test")
+
+ def test_progress_handler_concurrent_mutation_in_conversion(self):
+ self.cx.execute("create table if not exists test(a number)")
+
+ class Handler:
+ cx = self.cx
+ def __bool__(self):
+ # clear the progress handler
+ self.cx.set_progress_handler(None, 1)
+ raise ValueError # force PyObject_True() to fail
+
+ self.cx.set_progress_handler(Handler.__init__, 1)
+ self.assert_interrupted(self.cx.execute, "select * from test")
+
+ # Running with tracebacks makes the second execution of this
+ # function raise another exception because of a database change.
+
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@@ -345,11 +440,42 @@ def test_trace_bad_handler(self):
cx.set_trace_callback(lambda stmt: 5/0)
cx.execute("select 1")
- def test_trace_keyword_args(self):
+ def test_set_trace_callback_keyword_args(self):
with self.assertRaisesRegex(TypeError,
'takes exactly 1 positional argument'):
self.con.set_trace_callback(trace_callback=lambda: None)
+ # When a handler has an invalid signature, the exception raised is
+ # the same that would be raised if the handler "negatively" replied,
+ # but for the trace handler, exceptions are never re-raised (only
+ # printed when needed).
+
+ @with_tracebacks(
+ TypeError,
+ regex=r".*\(\) missing 6 required positional arguments",
+ )
+ def test_trace_handler_invalid_signature(self):
+ self.cx.execute("create table if not exists test(a number)")
+ self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None)
+ self.cx.execute("select * from test")
+
+ # Tests for checking that callback context mutations do not crash.
+ # Regression tests for https://github.com/python/cpython/issues/142830.
+
+ @with_tracebacks(ZeroDivisionError, regex="hello world")
+ def test_trace_callback_concurrent_mutation_in_call(self):
+ self.cx.execute("create table if not exists test(a number)")
+
+ class Handler:
+ cx = self.cx
+ def __call__(self, statement):
+ # clear the progress handler
+ self.cx.set_trace_callback(None)
+ raise ZeroDivisionError("hello world")
+
+ self.cx.set_trace_callback(Handler())
+ self.cx.execute("select * from test")
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_sqlite3/util.py b/Lib/test/test_sqlite3/util.py
index cccd062160fd40..cd88d7579671d0 100644
--- a/Lib/test/test_sqlite3/util.py
+++ b/Lib/test/test_sqlite3/util.py
@@ -50,7 +50,7 @@ def check_tracebacks(self, cm, exc, exc_regex, msg_regex, obj_name):
with contextlib.redirect_stderr(buf):
yield
- self.assertEqual(cm.unraisable.exc_type, exc)
+ self.assertIsSubclass(cm.unraisable.exc_type, exc)
if exc_regex:
msg = str(cm.unraisable.exc_value)
self.assertIsNotNone(exc_regex.search(msg), (exc_regex, msg))
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 26e2cddbef9d6b..d8527a2a79d32e 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -1450,6 +1450,8 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
assert(ctx_vp != NULL);
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
+ // Hold a reference to 'ctx' to prevent concurrent mutations.
+ Py_INCREF(ctx);
pysqlite_state *state = ctx->state;
assert(state != NULL);
@@ -1474,9 +1476,7 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
sqlite3_free((void *)expanded_sql);
}
if (py_statement) {
- Py_INCREF(ctx);
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
- Py_DECREF(ctx);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}
@@ -1485,6 +1485,7 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
}
exit:
+ Py_DECREF(ctx);
PyGILState_Release(gilstate);
return 0;
}
diff --git a/Modules/_sqlite/context.c b/Modules/_sqlite/context.c
index 01077eb33215aa..48f6886d11137f 100644
--- a/Modules/_sqlite/context.c
+++ b/Modules/_sqlite/context.c
@@ -49,7 +49,6 @@ callback_context_new_impl(PyTypeObject *type, PyObject *callable)
ctx->state = state;
PyObject_GC_Track(ctx);
return (PyObject *)ctx;
-
}
static int
@@ -104,8 +103,7 @@ static PyType_Spec callback_context_spec = {
PyObject *
pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable)
{
- PyTypeObject *type = state->CallbackContextType;
- return callback_context_new_impl(type, callable);
+ return callback_context_new_impl(state->CallbackContextType, callable);
}
int
diff --git a/Modules/_sqlite/context.h b/Modules/_sqlite/context.h
index 10f893ad50e382..9775b2cef46502 100644
--- a/Modules/_sqlite/context.h
+++ b/Modules/_sqlite/context.h
@@ -9,6 +9,12 @@ extern "C" {
#include "module.h"
+/*
+ * UDF/callback context structure.
+ *
+ * In order to ensure that the state pointer always outlives the callback
+ * context, we make sure it owns a reference to the module itself.
+ */
typedef struct {
PyObject_HEAD
PyObject *callable;
@@ -18,12 +24,6 @@ typedef struct {
#define pysqlite_CallbackContext_CAST(op) ((pysqlite_CallbackContext *)(op))
-/* Allocate a UDF/callback context structure. In order to ensure that the state
- * pointer always outlives the callback context, we make sure it owns a
- * reference to the module itself. create_callback_context() is always called
- * from connection methods, so we use the defining class to fetch the module
- * pointer.
- */
PyObject *
pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable);
From 80058035011c3edaaa9a49cd81d453d57e1f326a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?=
<10796600+picnixz@users.noreply.github.com>
Date: Mon, 29 Dec 2025 20:44:32 +0100
Subject: [PATCH 4/4] simplify tests
---
Lib/test/test_sqlite3/test_hooks.py | 51 ++++++++++++-----------------
1 file changed, 21 insertions(+), 30 deletions(-)
diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py
index 00d3bce5b7ee22..495ef97fa3c61c 100644
--- a/Lib/test/test_sqlite3/test_hooks.py
+++ b/Lib/test/test_sqlite3/test_hooks.py
@@ -154,13 +154,11 @@ def test_authorizer_invalid_signature(self):
def test_authorizer_concurrent_mutation_in_call(self):
self.cx.execute("create table if not exists test(a number)")
- class Handler:
- cx = self.cx
- def __call__(self, *a, **kw):
- self.cx.set_authorizer(None)
- raise ZeroDivisionError("hello world")
+ def handler(*a, **kw):
+ self.cx.set_authorizer(None)
+ raise ZeroDivisionError("hello world")
- self.cx.set_authorizer(Handler())
+ self.cx.set_authorizer(handler)
self.assert_not_authorized(self.cx.execute, "select * from test")
@with_tracebacks(OverflowError)
@@ -168,15 +166,13 @@ def test_authorizer_concurrent_mutation_with_overflown_value(self):
_testcapi = import_helper.import_module("_testcapi")
self.cx.execute("create table if not exists test(a number)")
- class Handler:
- cx = self.cx
- def __call__(self, *a, **kw):
- self.cx.set_authorizer(None)
- # We expect 'int' at the C level, so this one will raise
- # when converting via PyLong_Int().
- return _testcapi.INT_MAX + 1
-
- self.cx.set_authorizer(Handler())
+ def handler(*a, **kw):
+ self.cx.set_authorizer(None)
+ # We expect 'int' at the C level, so this one will raise
+ # when converting via PyLong_Int().
+ return _testcapi.INT_MAX + 1
+
+ self.cx.set_authorizer(handler)
self.assert_not_authorized(self.cx.execute, "select * from test")
@@ -294,21 +290,18 @@ def test_progress_handler_invalid_signature(self):
def test_progress_handler_concurrent_mutation_in_call(self):
self.cx.execute("create table if not exists test(a number)")
- class Handler:
- cx = self.cx
- def __call__(self, *a, **kw):
- self.cx.set_progress_handler(None, 1)
- raise ZeroDivisionError("hello world")
+ def handler(*a, **kw):
+ self.cx.set_progress_handler(None, 1)
+ raise ZeroDivisionError("hello world")
- self.cx.set_progress_handler(Handler(), 1)
+ self.cx.set_progress_handler(handler, 1)
self.assert_interrupted(self.cx.execute, "select * from test")
def test_progress_handler_concurrent_mutation_in_conversion(self):
self.cx.execute("create table if not exists test(a number)")
class Handler:
- cx = self.cx
- def __bool__(self):
+ def __bool__(_):
# clear the progress handler
self.cx.set_progress_handler(None, 1)
raise ValueError # force PyObject_True() to fail
@@ -466,14 +459,12 @@ def test_trace_handler_invalid_signature(self):
def test_trace_callback_concurrent_mutation_in_call(self):
self.cx.execute("create table if not exists test(a number)")
- class Handler:
- cx = self.cx
- def __call__(self, statement):
- # clear the progress handler
- self.cx.set_trace_callback(None)
- raise ZeroDivisionError("hello world")
+ def handler(statement):
+ # clear the progress handler
+ self.cx.set_trace_callback(None)
+ raise ZeroDivisionError("hello world")
- self.cx.set_trace_callback(Handler())
+ self.cx.set_trace_callback(handler)
self.cx.execute("select * from test")