Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions Lib/test/test_sqlite3/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,5 +503,72 @@ def test_recursive_cursor_iter(self):
self.cur.fetchall)


class CallbackClosesConnectionTests(unittest.TestCase):
"""Regression tests for gh-151030: callbacks that close the connection
during query execution must raise ProgrammingError, not crash."""

def _make_con(self) -> sqlite.Connection:
con = sqlite.connect(":memory:")
con.execute("CREATE TABLE t (v INTEGER)")
con.execute("INSERT INTO t VALUES (1)")
con.commit()
return con

def test_udf_closes_connection(self) -> None:
con = self._make_con()

def bad(x: int) -> int:
con.close()
return x

con.create_function("bad", 1, bad)
with self.assertRaises((sqlite.ProgrammingError, sqlite.OperationalError)):
con.execute("SELECT bad(v) FROM t").fetchall()

def test_progress_handler_closes_connection(self) -> None:
con = self._make_con()
fired = False

def handler() -> int:
nonlocal fired
if not fired:
fired = True
con.close()
return 0

con.set_progress_handler(handler, 1)
with self.assertRaises((sqlite.ProgrammingError, sqlite.OperationalError)):
con.execute("SELECT v FROM t").fetchall()

def test_trace_callback_closes_connection(self) -> None:
con = self._make_con()
fired = False

def tracer(statement: str) -> None:
nonlocal fired
if not fired:
fired = True
con.close()

con.set_trace_callback(tracer)
with self.assertRaises((sqlite.ProgrammingError, sqlite.OperationalError)):
con.execute("SELECT v FROM t").fetchall()

def test_authorizer_closes_connection(self) -> None:
con = self._make_con()
fired = False

def auth(action: int, arg1: str, arg2: str, db: str, trigger: str) -> int:
nonlocal fired
if not fired:
fired = True
con.close()
return sqlite.SQLITE_OK

con.set_authorizer(auth)
with self.assertRaises((sqlite.ProgrammingError, sqlite.OperationalError)):
con.execute("SELECT v FROM t").fetchall()


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix crashes in :mod:`sqlite3` when a callback closes the connection
while a query is being prepared or executed.
4 changes: 4 additions & 0 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ connection_close(pysqlite_Connection *self)
sqlite3 *db = self->db;
self->db = NULL;

/* Unregister callbacks before closing so that SQLite cannot invoke them
* again after free_callback_contexts releases their contexts. */
remove_callbacks(db);

Py_BEGIN_ALLOW_THREADS
/* The v2 close call always returns SQLITE_OK if given a valid database
* pointer (which we do), so we can safely ignore the return value */
Expand Down
24 changes: 18 additions & 6 deletions Modules/_sqlite/cursor.c
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ _pysqlite_fetch_one_row(pysqlite_Cursor* self)
return NULL;

sqlite3 *db = self->connection->db;
if (db == NULL) {
pysqlite_state *state = self->connection->state;
PyErr_SetString(state->ProgrammingError,
"Cannot operate on a closed database.");
goto error;
}
for (i = 0; i < numcols; i++) {
if (self->connection->detect_types
&& self->row_cast_map != NULL
Expand Down Expand Up @@ -530,10 +536,16 @@ begin_transaction(pysqlite_Connection *self)
static PyObject *
get_statement_from_cache(pysqlite_Cursor *self, PyObject *operation)
{
PyObject *args[] = { NULL, operation, }; // Borrowed ref.
PyObject *cache = self->connection->statement_cache;
PyObject *args[] = { NULL, operation, };

/* Hold a strong reference: a Python callback invoked during statement
* preparation (e.g. an authorizer) may close the connection, freeing
* the cache while the call is still in progress. */
PyObject *cache = Py_NewRef(self->connection->statement_cache);
size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
return PyObject_Vectorcall(cache, args + 1, nargsf, NULL);
PyObject *result = PyObject_Vectorcall(cache, args + 1, nargsf, NULL);
Py_DECREF(cache);
return result;
}

static inline int
Expand Down Expand Up @@ -957,7 +969,7 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation
}

if (rc == SQLITE_DONE) {
if (self->statement->is_dml) {
if (self->statement->is_dml && self->connection->db) {
self->rowcount += (long)sqlite3_changes(self->connection->db);
}
if (stmt_reset(self->statement) != SQLITE_OK) {
Expand All @@ -967,7 +979,7 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation
Py_XDECREF(parameters);
}

if (!multiple) {
if (!multiple && self->connection->db) {
sqlite_int64 lastrowid;

Py_BEGIN_ALLOW_THREADS
Expand Down Expand Up @@ -1157,7 +1169,7 @@ pysqlite_cursor_iternext(PyObject *op)
}
int rc = stmt_step(stmt);
if (rc == SQLITE_DONE) {
if (self->statement->is_dml) {
if (self->statement->is_dml && self->connection->db) {
self->rowcount = (long)sqlite3_changes(self->connection->db);
}
rc = stmt_reset(self->statement);
Expand Down
5 changes: 5 additions & 0 deletions Modules/_sqlite/util.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ set_error_from_code(pysqlite_state *state, int code)
int
set_error_from_db(pysqlite_state *state, sqlite3 *db)
{
if (db == NULL) {
PyErr_SetString(state->ProgrammingError,
"Cannot operate on a closed database.");
return SQLITE_MISUSE;
}
int errorcode = sqlite3_errcode(db);
PyObject *exc_class = get_exception_class(state, errorcode);
if (exc_class == NULL) {
Expand Down
Loading