Skip to content
Open
141 changes: 140 additions & 1 deletion Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import warnings

from test.support import (
SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess
SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess, subTests
)
from test.support import gc_collect
from test.support import threading_helper, import_helper
Expand Down Expand Up @@ -728,6 +728,21 @@ def test_database_keyword(self):
self.assertEqual(type(cx), sqlite.Connection)


class ParamsCxCloseInIterMany:
def __init__(self, cx):
self.cx = cx

def __iter__(self):
self.cx.close()
return iter([(1,), (2,), (3,)])


def ParamsCxCloseInNext(cx):
for i in range(10):
cx.close()
yield (i,)


class CursorTests(unittest.TestCase):
def setUp(self):
self.cx = sqlite.connect(":memory:")
Expand Down Expand Up @@ -859,6 +874,31 @@ def __getitem__(slf, x):
with self.assertRaises(ZeroDivisionError):
self.cu.execute("select name from test where name=?", L())

def test_execute_use_after_close_with_bind_parameters(self):
# Prevent SIGSEGV when closing the connection while binding parameters.
#
# Internally, the connection's state is checked after bind_parameters().
# Without this check, we would only be aware of the closed connection
# by calling an sqlite3 function afterwards. However, it is important
# that we report the error before leaving the execute() call.
#
# Regression test for https://github.com/python/cpython/issues/143198.

class PT:
def __getitem__(self, i):
cx.close()
return 1
def __len__(self):
return 1

cx = sqlite.connect(":memory:")
cx.execute("create table tmp(a number)")
self.addCleanup(cx.close)
cu = cx.cursor()
msg = r"Cannot operate on a closed database\."
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cu.execute("insert into tmp(a) values (?)", PT())

def test_execute_named_param_and_sequence(self):
dataset = (
("select :a", (1,)),
Expand Down Expand Up @@ -1030,6 +1070,50 @@ def test_execute_many_not_iterable(self):
with self.assertRaises(TypeError):
self.cu.executemany("insert into test(income) values (?)", 42)

@subTests("params_class", (ParamsCxCloseInIterMany, ParamsCxCloseInNext))
def test_executemany_use_after_close(self, params_class):
# Prevent SIGSEGV with iterable of parameters closing the connection.
# Regression test for https://github.com/python/cpython/issues/143198.
cx = sqlite.connect(":memory:")
cx.execute("create table tmp(a number)")
self.addCleanup(cx.close)
cu = cx.cursor()
msg = r"Cannot operate on a closed database\."
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cu.executemany("insert into tmp(a) values (?)", params_class(cx))

@subTests(("j", "n"), ([0, 1], [0, 3], [1, 3], [2, 3]))
@subTests("wtype", (list, lambda x: x))
def test_executemany_use_after_close_with_bind_parameters(self, j, n, wtype):
# Prevent SIGSEGV when closing the connection while binding parameters.
#
# Internally, the connection's state is checked after bind_parameters().
# Without this check, we would only be aware of the closed connection
# by calling an sqlite3 function afterwards. However, it is important
# that we report the error before leaving executemany() call.
#
# Regression test for https://github.com/python/cpython/issues/143198.

cx = sqlite.connect(":memory:")
cx.execute("create table tmp(a number)")
self.addCleanup(cx.close)

class PT:
def __init__(self, value):
self.value = value
def __getitem__(self, i):
if self.value == j:
cx.close()
return self.value
def __len__(self):
return 1

cu = cx.cursor()
msg = r"Cannot operate on a closed database\."
items = iter(wtype(map(PT, range(n))))
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cu.executemany("insert into tmp(a) values (?)", items)

def test_fetch_iter(self):
# Optional DB-API extension.
self.cu.execute("delete from test")
Expand Down Expand Up @@ -1711,6 +1795,24 @@ def test_connection_execute(self):
result = self.con.execute("select 5").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.execute")

def test_connection_execute_use_after_close_with_bind_parameters(self):
# See CursorTests.test_execute_use_after_close_with_bind_parameters().

cx = sqlite.connect(":memory:")
cx.execute("create table tmp(a number)")
self.addCleanup(cx.close)

class PT:
def __getitem__(self, i):
cx.close()
return 1
def __len__(self):
return 1

msg = r"Cannot operate on a closed database\."
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cx.execute("insert into tmp(a) values (?)", PT())

def test_connection_executemany(self):
con = self.con
con.execute("create table test(foo)")
Expand All @@ -1719,6 +1821,43 @@ def test_connection_executemany(self):
self.assertEqual(result[0][0], 3, "Basic test of Connection.executemany")
self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")

@subTests("params_class", (ParamsCxCloseInIterMany, ParamsCxCloseInNext))
def test_connection_executemany_use_after_close(self, params_class):
# Prevent SIGSEGV with iterable of parameters closing the connection.
# Regression test for https://github.com/python/cpython/issues/143198.
cx = sqlite.connect(":memory:")
cx.execute("create table tmp(a number)")
self.addCleanup(cx.close)
msg = r"Cannot operate on a closed database\."
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cx.executemany("insert into tmp(a) values (?)", params_class(cx))

@subTests(("j", "n"), ([0, 1], [0, 3], [1, 3], [2, 3]))
@subTests("wtype", (list, lambda x: x))
def test_connection_executemany_use_after_close_with_bind_parameters(
self, j, n, wtype,
):
# See CursorTests.test_executemany_use_after_close_with_bind_parameters().

cx = sqlite.connect(":memory:")
cx.execute("create table tmp(a number)")
self.addCleanup(cx.close)

class PT:
def __init__(self, value):
self.value = value
def __getitem__(self, i):
if self.value == j:
cx.close()
return self.value
def __len__(self):
return 1

items = iter(wtype(map(PT, range(n))))
msg = r"Cannot operate on a closed database\."
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cx.executemany("insert into tmp(a) values (?)", items)

def test_connection_executescript(self):
con = self.con
con.executescript("""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
:mod:`sqlite3`: fix crashes in :meth:`Connection.executemany <sqlite3.Connection.executemany>`
and :meth:`Cursor.executemany <sqlite3.Cursor.executemany>` when iterating over
the query's parameters closes the current connection. Patch by Bénédikt Tran.
40 changes: 15 additions & 25 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ static void free_callback_context(callback_context *ctx);
static void set_callback_context(callback_context **ctx_pp,
callback_context *ctx);
static int connection_close(pysqlite_Connection *self);
PyObject *_pysqlite_query_execute(pysqlite_Cursor *, int, PyObject *, PyObject *);

extern int _pysqlite_query_execute(pysqlite_Cursor *, PyObject *, PyObject *);
extern int _pysqlite_query_executemany(pysqlite_Cursor *, PyObject *, PyObject *);

static PyObject *
new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
Expand Down Expand Up @@ -1853,21 +1855,15 @@ pysqlite_connection_execute_impl(pysqlite_Connection *self, PyObject *sql,
PyObject *parameters)
/*[clinic end generated code: output=5be05ae01ee17ee4 input=27aa7792681ddba2]*/
{
PyObject* result = 0;

PyObject *cursor = pysqlite_connection_cursor_impl(self, NULL);
if (!cursor) {
goto error;
if (cursor == NULL) {
return NULL;
}

result = _pysqlite_query_execute((pysqlite_Cursor *)cursor, 0, sql, parameters);
if (!result) {
Py_CLEAR(cursor);
int rc = _pysqlite_query_execute((pysqlite_Cursor *)cursor, sql, parameters);
if (rc < 0) {
Py_DECREF(cursor);
return NULL;
}

error:
Py_XDECREF(result);

return cursor;
}

Expand All @@ -1886,21 +1882,15 @@ pysqlite_connection_executemany_impl(pysqlite_Connection *self,
PyObject *sql, PyObject *parameters)
/*[clinic end generated code: output=776cd2fd20bfe71f input=495be76551d525db]*/
{
PyObject* result = 0;

PyObject *cursor = pysqlite_connection_cursor_impl(self, NULL);
if (!cursor) {
goto error;
if (cursor == NULL) {
return NULL;
}

result = _pysqlite_query_execute((pysqlite_Cursor *)cursor, 1, sql, parameters);
if (!result) {
Py_CLEAR(cursor);
int rc = _pysqlite_query_executemany((pysqlite_Cursor *)cursor, sql, parameters);
if (rc < 0) {
Py_DECREF(cursor);
return NULL;
}

error:
Py_XDECREF(result);

return cursor;
}

Expand Down
Loading
Loading