diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 2b907e35131d06..495ef97fa3c61c 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -24,11 +24,15 @@ import sqlite3 as sqlite import unittest +from test.support import import_helper from test.support.os_helper import TESTFN, unlink from .util import memory_database, cx_limit, with_tracebacks from .util import MemoryDatabaseMixin +# TODO(picnixz): increase test coverage for other callbacks +# such as 'func', 'step', 'finalize', and 'collation'. + class CollationTests(MemoryDatabaseMixin, unittest.TestCase): @@ -129,8 +133,55 @@ def test_deregister_collation(self): self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') +class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase): + + def assert_not_authorized(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"): + func(*args, **kwargs) + + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_authorizer_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_authorizer(lambda: None) + self.assert_not_authorized(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_authorizer_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + def handler(*a, **kw): + self.cx.set_authorizer(None) + raise ZeroDivisionError("hello world") + + self.cx.set_authorizer(handler) + self.assert_not_authorized(self.cx.execute, "select * from test") + + @with_tracebacks(OverflowError) + def test_authorizer_concurrent_mutation_with_overflown_value(self): + _testcapi = import_helper.import_module("_testcapi") + self.cx.execute("create table if not exists test(a number)") + + def handler(*a, **kw): + self.cx.set_authorizer(None) + # We expect 'int' at the C level, so this one will raise + # when converting via PyLong_Int(). + return _testcapi.INT_MAX + 1 + + self.cx.set_authorizer(handler) + self.assert_not_authorized(self.cx.execute, "select * from test") + + class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): + def assert_interrupted(self, func, /, *args, **kwargs): + with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"): + func(*args, **kwargs) + def test_progress_handler_used(self): """ Test that the progress handler is invoked once it is set. @@ -219,11 +270,48 @@ def bad_progress(): create table foo(a, b) """) - def test_progress_handler_keyword_args(self): + def test_set_progress_handler_keyword_args(self): with self.assertRaisesRegex(TypeError, 'takes at least 1 positional argument'): self.con.set_progress_handler(progress_handler=lambda: None, n=1) + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied. + + def test_progress_handler_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_progress_handler(lambda x: None, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_progress_handler_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + def handler(*a, **kw): + self.cx.set_progress_handler(None, 1) + raise ZeroDivisionError("hello world") + + self.cx.set_progress_handler(handler, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + def test_progress_handler_concurrent_mutation_in_conversion(self): + self.cx.execute("create table if not exists test(a number)") + + class Handler: + def __bool__(_): + # clear the progress handler + self.cx.set_progress_handler(None, 1) + raise ValueError # force PyObject_True() to fail + + self.cx.set_progress_handler(Handler.__init__, 1) + self.assert_interrupted(self.cx.execute, "select * from test") + + # Running with tracebacks makes the second execution of this + # function raise another exception because of a database change. + class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): @@ -345,11 +433,40 @@ def test_trace_bad_handler(self): cx.set_trace_callback(lambda stmt: 5/0) cx.execute("select 1") - def test_trace_keyword_args(self): + def test_set_trace_callback_keyword_args(self): with self.assertRaisesRegex(TypeError, 'takes exactly 1 positional argument'): self.con.set_trace_callback(trace_callback=lambda: None) + # When a handler has an invalid signature, the exception raised is + # the same that would be raised if the handler "negatively" replied, + # but for the trace handler, exceptions are never re-raised (only + # printed when needed). + + @with_tracebacks( + TypeError, + regex=r".*\(\) missing 6 required positional arguments", + ) + def test_trace_handler_invalid_signature(self): + self.cx.execute("create table if not exists test(a number)") + self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None) + self.cx.execute("select * from test") + + # Tests for checking that callback context mutations do not crash. + # Regression tests for https://github.com/python/cpython/issues/142830. + + @with_tracebacks(ZeroDivisionError, regex="hello world") + def test_trace_callback_concurrent_mutation_in_call(self): + self.cx.execute("create table if not exists test(a number)") + + def handler(statement): + # clear the progress handler + self.cx.set_trace_callback(None) + raise ZeroDivisionError("hello world") + + self.cx.set_trace_callback(handler) + self.cx.execute("select * from test") + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/util.py b/Lib/test/test_sqlite3/util.py index cccd062160fd40..cd88d7579671d0 100644 --- a/Lib/test/test_sqlite3/util.py +++ b/Lib/test/test_sqlite3/util.py @@ -50,7 +50,7 @@ def check_tracebacks(self, cm, exc, exc_regex, msg_regex, obj_name): with contextlib.redirect_stderr(buf): yield - self.assertEqual(cm.unraisable.exc_type, exc) + self.assertIsSubclass(cm.unraisable.exc_type, exc) if exc_regex: msg = str(cm.unraisable.exc_value) self.assertIsNotNone(exc_regex.search(msg), (exc_regex, msg)) diff --git a/Makefile.pre.in b/Makefile.pre.in index a6beb96d12a3f2..36c81b0337fddd 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -3429,7 +3429,7 @@ MODULE__SSL_DEPS=$(srcdir)/Modules/_ssl.h $(srcdir)/Modules/_ssl/cert.c $(srcdir MODULE__TESTCAPI_DEPS=$(srcdir)/Modules/_testcapi/parts.h $(srcdir)/Modules/_testcapi/util.h MODULE__TESTLIMITEDCAPI_DEPS=$(srcdir)/Modules/_testlimitedcapi/testcapi_long.h $(srcdir)/Modules/_testlimitedcapi/parts.h $(srcdir)/Modules/_testlimitedcapi/util.h MODULE__TESTINTERNALCAPI_DEPS=$(srcdir)/Modules/_testinternalcapi/parts.h -MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h +MODULE__SQLITE3_DEPS=$(srcdir)/Modules/_sqlite/connection.h $(srcdir)/Modules/_sqlite/context.h $(srcdir)/Modules/_sqlite/cursor.h $(srcdir)/Modules/_sqlite/microprotocols.h $(srcdir)/Modules/_sqlite/module.h $(srcdir)/Modules/_sqlite/prepare_protocol.h $(srcdir)/Modules/_sqlite/row.h $(srcdir)/Modules/_sqlite/util.h MODULE__ZSTD_DEPS=$(srcdir)/Modules/_zstd/_zstdmodule.h $(srcdir)/Modules/_zstd/buffer.h $(srcdir)/Modules/_zstd/zstddict.h CODECS_COMMON_HEADERS=$(srcdir)/Modules/cjkcodecs/multibytecodec.h $(srcdir)/Modules/cjkcodecs/cjkcodecs.h diff --git a/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst new file mode 100644 index 00000000000000..246979e91d76b5 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-12-28-13-12-40.gh-issue-142830.uEyd6r.rst @@ -0,0 +1,2 @@ +:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks +are mutated during a callback execution. Patch by Bénédikt Tran. diff --git a/Modules/Setup.stdlib.in b/Modules/Setup.stdlib.in index 2a4b937ce6bf80..59d649825fbe9d 100644 --- a/Modules/Setup.stdlib.in +++ b/Modules/Setup.stdlib.in @@ -149,7 +149,7 @@ # needs -lncurses[w] and -lpanel[w] @MODULE__CURSES_PANEL_TRUE@_curses_panel _curses_panel.c -@MODULE__SQLITE3_TRUE@_sqlite3 _sqlite/blob.c _sqlite/connection.c _sqlite/cursor.c _sqlite/microprotocols.c _sqlite/module.c _sqlite/prepare_protocol.c _sqlite/row.c _sqlite/statement.c _sqlite/util.c +@MODULE__SQLITE3_TRUE@_sqlite3 _sqlite/blob.c _sqlite/connection.c _sqlite/context.c _sqlite/cursor.c _sqlite/microprotocols.c _sqlite/module.c _sqlite/prepare_protocol.c _sqlite/row.c _sqlite/statement.c _sqlite/util.c # needs -lssl and -lcrypt @MODULE__SSL_TRUE@_ssl _ssl.c diff --git a/Modules/_sqlite/clinic/context.c.h b/Modules/_sqlite/clinic/context.c.h new file mode 100644 index 00000000000000..9788f819ac7f68 --- /dev/null +++ b/Modules/_sqlite/clinic/context.c.h @@ -0,0 +1,30 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +#include "pycore_modsupport.h" // _PyArg_CheckPositional() + +static PyObject * +callback_context_new_impl(PyTypeObject *type, PyObject *callable); + +static PyObject * +callback_context_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) +{ + PyObject *return_value = NULL; + PyTypeObject *base_tp = clinic_state()->CallbackContextType; + PyObject *callable; + + if ((type == base_tp || type->tp_init == base_tp->tp_init) && + !_PyArg_NoKeywords("_CallbackContext", kwargs)) { + goto exit; + } + if (!_PyArg_CheckPositional("_CallbackContext", PyTuple_GET_SIZE(args), 1, 1)) { + goto exit; + } + callable = PyTuple_GET_ITEM(args, 0); + return_value = callback_context_new_impl(type, callable); + +exit: + return return_value; +} +/*[clinic end generated code: output=370246c27daeaaa1 input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 83ff8e60557c07..d8527a2a79d32e 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -145,9 +145,6 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp /*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/ static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self); -static void free_callback_context(callback_context *ctx); -static void set_callback_context(callback_context **ctx_pp, - callback_context *ctx); static int connection_close(pysqlite_Connection *self); PyObject *_pysqlite_query_execute(pysqlite_Cursor *, int, PyObject *, PyObject *); @@ -377,13 +374,13 @@ output pop [clinic start generated code]*/ /*[clinic end generated code: output=da39a3ee5e6b4b0d input=b899ba9273edcce7]*/ -#define VISIT_CALLBACK_CONTEXT(ctx) \ -do { \ - if (ctx) { \ - Py_VISIT(ctx->callable); \ - Py_VISIT(ctx->module); \ - } \ -} while (0) +static void +clear_callback_contexts(pysqlite_Connection *self) +{ + Py_CLEAR(self->trace_ctx); + Py_CLEAR(self->progress_ctx); + Py_CLEAR(self->authorizer_ctx); +} static int connection_traverse(PyObject *op, visitproc visit, void *arg) @@ -395,22 +392,12 @@ connection_traverse(PyObject *op, visitproc visit, void *arg) Py_VISIT(self->blobs); Py_VISIT(self->row_factory); Py_VISIT(self->text_factory); - VISIT_CALLBACK_CONTEXT(self->trace_ctx); - VISIT_CALLBACK_CONTEXT(self->progress_ctx); - VISIT_CALLBACK_CONTEXT(self->authorizer_ctx); -#undef VISIT_CALLBACK_CONTEXT + Py_VISIT(self->trace_ctx); + Py_VISIT(self->progress_ctx); + Py_VISIT(self->authorizer_ctx); return 0; } -static inline void -clear_callback_context(callback_context *ctx) -{ - if (ctx != NULL) { - Py_CLEAR(ctx->callable); - Py_CLEAR(ctx->module); - } -} - static int connection_clear(PyObject *op) { @@ -420,20 +407,10 @@ connection_clear(PyObject *op) Py_CLEAR(self->blobs); Py_CLEAR(self->row_factory); Py_CLEAR(self->text_factory); - clear_callback_context(self->trace_ctx); - clear_callback_context(self->progress_ctx); - clear_callback_context(self->authorizer_ctx); + clear_callback_contexts(self); return 0; } -static void -free_callback_contexts(pysqlite_Connection *self) -{ - set_callback_context(&self->trace_ctx, NULL); - set_callback_context(&self->progress_ctx, NULL); - set_callback_context(&self->authorizer_ctx, NULL); -} - static void remove_callbacks(sqlite3 *db) { @@ -474,7 +451,7 @@ connection_close(pysqlite_Connection *self) (void)sqlite3_close_v2(db); Py_END_ALLOW_THREADS - free_callback_contexts(self); + clear_callback_contexts(self); return rc; } @@ -814,7 +791,7 @@ _pysqlite_set_result(sqlite3_context* context, PyObject* py_val) sqlite3_result_blob(context, view.buf, (int)view.len, SQLITE_TRANSIENT); PyBuffer_Release(&view); } else { - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); PyErr_Format(ctx->state->ProgrammingError, "User-defined functions cannot return '%s' values to " "SQLite", @@ -893,7 +870,7 @@ _pysqlite_build_py_params(sqlite3_context *context, int argc, } static void -print_or_clear_traceback(callback_context *ctx) +print_or_clear_traceback(pysqlite_CallbackContext *ctx) { assert(ctx != NULL); assert(ctx->state != NULL); @@ -920,7 +897,7 @@ set_sqlite_error(sqlite3_context *context, const char *msg) else { sqlite3_result_error(context, msg, -1); } - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); print_or_clear_traceback(ctx); } @@ -935,9 +912,11 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) args = _pysqlite_build_py_params(context, argc, argv); if (args) { - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + Py_INCREF(ctx); py_retval = PyObject_CallObject(ctx->callable, args); + Py_DECREF(ctx); Py_DECREF(args); } @@ -963,8 +942,10 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) PyObject** aggregate_instance; PyObject* stepmethod = NULL; - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); if (aggregate_instance == NULL) { @@ -994,6 +975,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) } function_result = PyObject_CallObject(stepmethod, args); + Py_CLEAR(ctx); Py_DECREF(args); if (!function_result) { @@ -1002,6 +984,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) } error: + Py_XDECREF(ctx); Py_XDECREF(stepmethod); Py_XDECREF(function_result); @@ -1032,10 +1015,12 @@ final_callback(sqlite3_context *context) // Keep the exception (if any) of the last call to step, value, or inverse PyObject *exc = PyErr_GetRaisedException(); - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + Py_INCREF(ctx); function_result = PyObject_CallMethodNoArgs(*aggregate_instance, ctx->state->str_finalize); + Py_DECREF(ctx); Py_DECREF(*aggregate_instance); ok = 0; @@ -1095,55 +1080,14 @@ _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self) return 0; } -/* Allocate a UDF/callback context structure. In order to ensure that the state - * pointer always outlives the callback context, we make sure it owns a - * reference to the module itself. create_callback_context() is always called - * from connection methods, so we use the defining class to fetch the module - * pointer. - */ -static callback_context * -create_callback_context(PyTypeObject *cls, PyObject *callable) -{ - callback_context *ctx = PyMem_Malloc(sizeof(callback_context)); - if (ctx != NULL) { - PyObject *module = PyType_GetModule(cls); - ctx->callable = Py_NewRef(callable); - ctx->module = Py_NewRef(module); - ctx->state = pysqlite_get_state(module); - } - return ctx; -} - -static void -free_callback_context(callback_context *ctx) -{ - assert(ctx != NULL); - Py_XDECREF(ctx->callable); - Py_XDECREF(ctx->module); - PyMem_Free(ctx); -} - -static void -set_callback_context(callback_context **ctx_pp, callback_context *ctx) -{ - assert(ctx_pp != NULL); - callback_context *tmp = *ctx_pp; - *ctx_pp = ctx; - if (tmp != NULL) { - free_callback_context(tmp); - } -} - static void destructor_callback(void *ctx) { - if (ctx != NULL) { - // This function may be called without the GIL held, so we need to - // ensure that we destroy 'ctx' with the GIL held. - PyGILState_STATE gstate = PyGILState_Ensure(); - free_callback_context((callback_context *)ctx); - PyGILState_Release(gstate); - } + // This function may be called without the GIL held, so we need to + // ensure that we destroy 'ctx' with the GIL held. + PyGILState_STATE gstate = PyGILState_Ensure(); + Py_DECREF(ctx); + PyGILState_Release(gstate); } static int @@ -1194,7 +1138,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, if (deterministic) { flags |= SQLITE_DETERMINISTIC; } - callback_context *ctx = create_callback_context(cls, func); + PyObject *ctx = pysqlite_create_callback_context(self->state, func); if (ctx == NULL) { return NULL; } @@ -1225,8 +1169,10 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) { PyGILState_STATE gilstate = PyGILState_Ensure(); - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); int size = sizeof(PyObject *); PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); @@ -1255,9 +1201,11 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) "user-defined aggregate's 'inverse' method raised error"); goto exit; } + Py_CLEAR(ctx); Py_DECREF(res); exit: + Py_XDECREF(ctx); Py_XDECREF(method); PyGILState_Release(gilstate); } @@ -1273,7 +1221,7 @@ value_callback(sqlite3_context *context) { PyGILState_STATE gilstate = PyGILState_Ensure(); - callback_context *ctx = (callback_context *)sqlite3_user_data(context); + pysqlite_CallbackContext *ctx = sqlite3_user_data(context); assert(ctx != NULL); int size = sizeof(PyObject *); @@ -1281,7 +1229,10 @@ value_callback(sqlite3_context *context) assert(cls != NULL); assert(*cls != NULL); + Py_INCREF(ctx); PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value); + Py_DECREF(ctx); + if (res == NULL) { int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); set_sqlite_error(context, attr_err @@ -1344,7 +1295,7 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, 0, 0, 0, 0, 0, 0); } else { - callback_context *ctx = create_callback_context(cls, aggregate_class); + PyObject *ctx = pysqlite_create_callback_context(self->state, aggregate_class); if (ctx == NULL) { return NULL; } @@ -1395,7 +1346,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, return NULL; } - callback_context *ctx = create_callback_context(cls, aggregate_class); + PyObject *ctx = pysqlite_create_callback_context(self->state, aggregate_class); if (ctx == NULL) { return NULL; } @@ -1413,7 +1364,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, } static int -authorizer_callback(void *ctx, int action, const char *arg1, +authorizer_callback(void *ctx_vp, int action, const char *arg1, const char *arg2 , const char *dbname, const char *access_attempt_source) { @@ -1422,11 +1373,13 @@ authorizer_callback(void *ctx, int action, const char *arg1, PyObject *ret; int rc = SQLITE_DENY; - assert(ctx != NULL); - PyObject *callable = ((callback_context *)ctx)->callable; - ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname, - access_attempt_source); + assert(ctx_vp != NULL); + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + ret = PyObject_CallFunction(ctx->callable, "issss", action, arg1, arg2, + dbname, access_attempt_source); if (ret == NULL) { print_or_clear_traceback(ctx); rc = SQLITE_DENY; @@ -1444,22 +1397,26 @@ authorizer_callback(void *ctx, int action, const char *arg1, } Py_DECREF(ret); } + Py_DECREF(ctx); PyGILState_Release(gilstate); return rc; } static int -progress_callback(void *ctx) +progress_callback(void *ctx_vp) { PyGILState_STATE gilstate = PyGILState_Ensure(); int rc; PyObject *ret; - assert(ctx != NULL); - PyObject *callable = ((callback_context *)ctx)->callable; - ret = PyObject_CallNoArgs(callable); + assert(ctx_vp != NULL); + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + + ret = PyObject_CallNoArgs(ctx->callable); if (!ret) { /* abort query if error occurred */ rc = -1; @@ -1471,7 +1428,7 @@ progress_callback(void *ctx) if (rc < 0) { print_or_clear_traceback(ctx); } - + Py_DECREF(ctx); PyGILState_Release(gilstate); return rc; } @@ -1483,7 +1440,7 @@ progress_callback(void *ctx) * to ensure future compatibility. */ static int -trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) +trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql) { if (type != SQLITE_TRACE_STMT) { return 0; @@ -1491,8 +1448,11 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) PyGILState_STATE gilstate = PyGILState_Ensure(); - assert(ctx != NULL); - pysqlite_state *state = ((callback_context *)ctx)->state; + assert(ctx_vp != NULL); + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + pysqlite_state *state = ctx->state; assert(state != NULL); PyObject *py_statement = NULL; @@ -1506,7 +1466,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) PyErr_SetString(state->DataError, "Expanded SQL string exceeds the maximum string length"); - print_or_clear_traceback((callback_context *)ctx); + print_or_clear_traceback(ctx); // Fall back to unexpanded sql py_statement = PyUnicode_FromString((const char *)sql); @@ -1516,16 +1476,16 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) sqlite3_free((void *)expanded_sql); } if (py_statement) { - PyObject *callable = ((callback_context *)ctx)->callable; - PyObject *ret = PyObject_CallOneArg(callable, py_statement); + PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement); Py_DECREF(py_statement); Py_XDECREF(ret); } if (PyErr_Occurred()) { - print_or_clear_traceback((callback_context *)ctx); + print_or_clear_traceback(ctx); } exit: + Py_DECREF(ctx); PyGILState_Release(gilstate); return 0; } @@ -1552,21 +1512,22 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self, int rc; if (callable == Py_None) { + /* None clears the authorizer callback previously set */ rc = sqlite3_set_authorizer(self->db, NULL, NULL); - set_callback_context(&self->authorizer_ctx, NULL); + Py_CLEAR(self->authorizer_ctx); } else { - callback_context *ctx = create_callback_context(cls, callable); + PyObject *ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } rc = sqlite3_set_authorizer(self->db, authorizer_callback, ctx); - set_callback_context(&self->authorizer_ctx, ctx); + Py_XSETREF(self->authorizer_ctx, ctx); } if (rc != SQLITE_OK) { PyErr_SetString(self->OperationalError, "Error setting authorizer callback"); - set_callback_context(&self->authorizer_ctx, NULL); + Py_CLEAR(self->authorizer_ctx); return NULL; } Py_RETURN_NONE; @@ -1604,15 +1565,15 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self, if (callable == Py_None) { /* None clears the progress handler previously set */ sqlite3_progress_handler(self->db, 0, 0, (void*)0); - set_callback_context(&self->progress_ctx, NULL); + Py_CLEAR(self->progress_ctx); } else { - callback_context *ctx = create_callback_context(cls, callable); + PyObject *ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } sqlite3_progress_handler(self->db, n, progress_callback, ctx); - set_callback_context(&self->progress_ctx, ctx); + Py_XSETREF(self->progress_ctx, ctx); } Py_RETURN_NONE; } @@ -1637,6 +1598,7 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self, return NULL; } + int rc; if (callable == Py_None) { /* * None clears the trace callback previously set @@ -1645,18 +1607,22 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self, * - https://sqlite.org/c3ref/c_trace.html * - https://sqlite.org/c3ref/trace_v2.html */ - sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, 0, 0); - set_callback_context(&self->trace_ctx, NULL); + rc = sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, 0, 0); + Py_CLEAR(self->trace_ctx); } else { - callback_context *ctx = create_callback_context(cls, callable); + PyObject *ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } - sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx); - set_callback_context(&self->trace_ctx, ctx); + rc = sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx); + Py_XSETREF(self->trace_ctx, ctx); + } + if (rc != SQLITE_OK) { + PyErr_SetString(self->OperationalError, "Error setting trace callback"); + Py_CLEAR(self->trace_ctx); + return NULL; } - Py_RETURN_NONE; } @@ -1945,6 +1911,7 @@ collation_callback(void *context, int text1_length, const void *text1_data, { PyGILState_STATE gilstate = PyGILState_Ensure(); + pysqlite_CallbackContext *ctx = NULL; PyObject* string1 = 0; PyObject* string2 = 0; PyObject* retval = NULL; @@ -1966,8 +1933,11 @@ collation_callback(void *context, int text1_length, const void *text1_data, goto finally; } - callback_context *ctx = (callback_context *)context; + ctx = pysqlite_CallbackContext_CAST(context); assert(ctx != NULL); + // Hold a reference to 'ctx' to prevent concurrent mutations. + Py_INCREF(ctx); + PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs. size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET; retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL); @@ -1987,8 +1957,10 @@ collation_callback(void *context, int text1_length, const void *text1_data, else if (longval < 0) result = -1; } + Py_CLEAR(ctx); finally: + Py_XDECREF(ctx); Py_XDECREF(string1); Py_XDECREF(string2); Py_XDECREF(retval); @@ -2185,7 +2157,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, return NULL; } - callback_context *ctx = NULL; + PyObject *ctx = NULL; int rc; int flags = SQLITE_UTF8; if (callable == Py_None) { @@ -2193,11 +2165,13 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, NULL, NULL, NULL); } else { + // TODO(picnixz): defer this check to the context's constructor + // and do it for all other functions that create a context. if (!PyCallable_Check(callable)) { PyErr_SetString(PyExc_TypeError, "parameter must be callable"); return NULL; } - ctx = create_callback_context(cls, callable); + ctx = pysqlite_create_callback_context(self->state, callable); if (ctx == NULL) { return NULL; } @@ -2211,9 +2185,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, * called if sqlite3_create_collation_v2() fails, so we have to free * the context before returning. */ - if (callable != Py_None) { - free_callback_context(ctx); - } + Py_XDECREF(ctx); set_error_from_db(self->state, self->db); return NULL; } diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 7a748ee3ea0c58..20c20d772befa3 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -23,21 +23,16 @@ #ifndef PYSQLITE_CONNECTION_H #define PYSQLITE_CONNECTION_H + #include "Python.h" #include "pythread.h" #include "structmember.h" +#include "context.h" #include "module.h" #include "sqlite3.h" -typedef struct _callback_context -{ - PyObject *callable; - PyObject *module; - pysqlite_state *state; -} callback_context; - enum autocommit_mode { AUTOCOMMIT_LEGACY = LEGACY_TRANSACTION_CONTROL, AUTOCOMMIT_ENABLED = 1, @@ -88,9 +83,9 @@ typedef struct PyObject* text_factory; // Remember contexts used by the trace, progress, and authoriser callbacks - callback_context *trace_ctx; - callback_context *progress_ctx; - callback_context *authorizer_ctx; + PyObject *trace_ctx; + PyObject *progress_ctx; + PyObject *authorizer_ctx; /* Exception objects: borrowed refs. */ PyObject* Warning; diff --git a/Modules/_sqlite/context.c b/Modules/_sqlite/context.c new file mode 100644 index 00000000000000..48f6886d11137f --- /dev/null +++ b/Modules/_sqlite/context.c @@ -0,0 +1,119 @@ +#ifndef Py_BUILD_CORE_BUILTIN +# define Py_BUILD_CORE_MODULE 1 +#endif + +#include "context.h" +#include "module.h" + +#define clinic_state() (pysqlite_get_state_by_type(type)) +#include "clinic/context.c.h" +#undef clinic_state + +/*[clinic input] +module _sqlite3 +class _sqlite3._CallbackContext "pysqlite_CallbackContext *" "clinic_state()->CallbackContextType" +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=f06c18b7c3666e05]*/ + +/*[clinic input] +@classmethod +_sqlite3._CallbackContext.__new__ as callback_context_new + + callable: object + / + +[clinic start generated code]*/ + +static PyObject * +callback_context_new_impl(PyTypeObject *type, PyObject *callable) +/*[clinic end generated code: output=46b4f355475d88cc input=ae33656b48f65a6d]*/ +{ + PyObject *module = PyType_GetModuleByDef(type, &_sqlite3module); + assert(module != NULL); + pysqlite_state *state = pysqlite_get_state_by_type(type); + assert(state != NULL); + + pysqlite_CallbackContext *ctx = PyObject_GC_New( + pysqlite_CallbackContext, + state->CallbackContextType + ); + if (ctx == NULL) { + return NULL; + } + + // TODO(picnixz): check that 'callable' is effectively callable + // instead of relying on a generic TypeError when attempting to + // call it. + ctx->callable = Py_NewRef(callable); + ctx->module = Py_NewRef(module); + ctx->state = state; + PyObject_GC_Track(ctx); + return (PyObject *)ctx; +} + +static int +callback_context_clear(PyObject *op) +{ + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(op); + Py_CLEAR(ctx->callable); + Py_CLEAR(ctx->module); + return 0; +} + +static void +callback_context_dealloc(PyObject *op) +{ + PyTypeObject *type = Py_TYPE(op); + PyObject_GC_UnTrack(op); + (void)type->tp_clear(op); + type->tp_free(op); + Py_DECREF(type); +} + +static int +callback_context_traverse(PyObject *op, visitproc visit, void *arg) +{ + pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(op); + Py_VISIT(Py_TYPE(op)); + Py_VISIT(ctx->callable); + Py_VISIT(ctx->module); + return 0; +} + +static PyType_Slot callback_context_slots[] = { + {Py_tp_new, callback_context_new}, + {Py_tp_clear, callback_context_clear}, + {Py_tp_dealloc, callback_context_dealloc}, + {Py_tp_traverse, callback_context_traverse}, + {0, NULL}, +}; + +static PyType_Spec callback_context_spec = { + .name = MODULE_NAME "._CallbackContext", + .basicsize = sizeof(pysqlite_CallbackContext), + .flags = ( + Py_TPFLAGS_DEFAULT + | Py_TPFLAGS_DISALLOW_INSTANTIATION + | Py_TPFLAGS_IMMUTABLETYPE + | Py_TPFLAGS_HAVE_GC + ), + .slots = callback_context_slots, +}; + +PyObject * +pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable) +{ + return callback_context_new_impl(state->CallbackContextType, callable); +} + +int +pysqlite_context_setup_types(PyObject *module) +{ + PyObject *type = PyType_FromModuleAndSpec(module, &callback_context_spec, NULL); + if (type == NULL) { + return -1; + } + pysqlite_state *state = pysqlite_get_state(module); + state->CallbackContextType = (PyTypeObject *)type; + return 0; +} diff --git a/Modules/_sqlite/context.h b/Modules/_sqlite/context.h new file mode 100644 index 00000000000000..9775b2cef46502 --- /dev/null +++ b/Modules/_sqlite/context.h @@ -0,0 +1,37 @@ +#ifndef PYSQLITE_CALLBACK_CONTEXT_H +#define PYSQLITE_CALLBACK_CONTEXT_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "Python.h" + +#include "module.h" + +/* + * UDF/callback context structure. + * + * In order to ensure that the state pointer always outlives the callback + * context, we make sure it owns a reference to the module itself. + */ +typedef struct { + PyObject_HEAD + PyObject *callable; + PyObject *module; + pysqlite_state *state; +} pysqlite_CallbackContext; + +#define pysqlite_CallbackContext_CAST(op) ((pysqlite_CallbackContext *)(op)) + +PyObject * +pysqlite_create_callback_context(pysqlite_state *state, PyObject *callable); + +int +pysqlite_context_setup_types(PyObject *module); + +#ifdef __cplusplus +} +#endif + +#endif // !PYSQLITE_CALLBACK_CONTEXT_H diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 831dd9219f77ab..923268159a4197 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -686,6 +686,7 @@ module_exec(PyObject *module) if ((pysqlite_row_setup_types(module) < 0) || (pysqlite_cursor_setup_types(module) < 0) || + (pysqlite_context_setup_types(module) < 0) || (pysqlite_connection_setup_types(module) < 0) || (pysqlite_statement_setup_types(module) < 0) || (pysqlite_prepare_protocol_setup_types(module) < 0) || diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index a4ca45cf6326a9..4fee440146c7f6 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -55,6 +55,7 @@ typedef struct { int enable_callback_tracebacks; PyTypeObject *BlobType; + PyTypeObject *CallbackContextType; PyTypeObject *ConnectionType; PyTypeObject *CursorType; PyTypeObject *PrepareProtocolType; diff --git a/PCbuild/_sqlite3.vcxproj b/PCbuild/_sqlite3.vcxproj index 9ae0a0fc3a009d..4919b28293875e 100644 --- a/PCbuild/_sqlite3.vcxproj +++ b/PCbuild/_sqlite3.vcxproj @@ -99,6 +99,7 @@ + @@ -110,6 +111,7 @@ + @@ -135,4 +137,4 @@ - \ No newline at end of file + diff --git a/PCbuild/_sqlite3.vcxproj.filters b/PCbuild/_sqlite3.vcxproj.filters index f4a265eba7dd80..6a930a01d2e519 100644 --- a/PCbuild/_sqlite3.vcxproj.filters +++ b/PCbuild/_sqlite3.vcxproj.filters @@ -15,6 +15,9 @@ Header Files + + Header Files + Header Files @@ -44,6 +47,9 @@ Source Files + + Source Files + Source Files @@ -74,4 +80,4 @@ Resource Files - \ No newline at end of file +