From bba098db4901e51b85fd1c5659696648fa7b1413 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 21 Jul 2025 18:57:40 -0400 Subject: [PATCH 1/6] transactioner --- chia/_tests/db/test_sqlite_wrapper.py | 604 ++++++++++++++++++++++++++ chia/util/sqlite_wrapper.py | 208 +++++++++ chia/util/transactioner.py | 376 ++++++++++++++++ 3 files changed, 1188 insertions(+) create mode 100644 chia/_tests/db/test_sqlite_wrapper.py create mode 100644 chia/util/sqlite_wrapper.py create mode 100644 chia/util/transactioner.py diff --git a/chia/_tests/db/test_sqlite_wrapper.py b/chia/_tests/db/test_sqlite_wrapper.py new file mode 100644 index 000000000000..5fb0fe093b91 --- /dev/null +++ b/chia/_tests/db/test_sqlite_wrapper.py @@ -0,0 +1,604 @@ +from __future__ import annotations + +import asyncio +import contextlib +import tempfile +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Optional + +import aiosqlite +import pytest + +# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469 +from _pytest.fixtures import SubRequest + +from chia._tests.util.misc import Marks, boolean_datacases, datacases +from chia.util import sqlite_wrapper +from chia.util.sqlite_wrapper import SqliteConnection +from chia.util.task_referencer import create_referenced_task +from chia.util.transactioner import ( + ForeignKeyError, + InternalError, + NestedForeignKeyDelayedRequestError, + Transactioner, + generate_in_memory_db_uri, +) + +DBWrapper2 = Transactioner[SqliteConnection] + + +@asynccontextmanager +async def DBConnection( + db_version: int, + foreign_keys: Optional[bool] = None, + row_factory: Optional[type[aiosqlite.Row]] = None, +) -> AsyncIterator[DBWrapper2]: + db_uri = generate_in_memory_db_uri() + async with sqlite_wrapper.managed( + database=db_uri, + uri=True, + reader_count=4, + db_version=db_version, + foreign_keys=foreign_keys, + row_factory=row_factory, + ) as _db_wrapper: + yield _db_wrapper + + +@asynccontextmanager +async def PathDBConnection(db_version: int) -> AsyncIterator[DBWrapper2]: + with tempfile.TemporaryDirectory() as directory: + db_path = Path(directory).joinpath("db.sqlite") + async with sqlite_wrapper.managed(database=db_path, reader_count=4, db_version=db_version) as _db_wrapper: + yield _db_wrapper + + +if TYPE_CHECKING: + ConnectionContextManager = contextlib.AbstractAsyncContextManager[SqliteConnection] + GetReaderMethod = Callable[[DBWrapper2], Callable[[], ConnectionContextManager]] + + +class UniqueError(Exception): + """Used to uniquely trigger the exception path out of the context managers.""" + + +async def increment_counter(db_wrapper: DBWrapper2) -> None: + async with db_wrapper.writer_maybe_transaction() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [old_value] = row + + await asyncio.sleep(0) + + new_value = old_value + 1 + await connection.execute("UPDATE counter SET value = :value", {"value": new_value}) + + +async def decrement_counter(db_wrapper: DBWrapper2) -> None: + async with db_wrapper.writer_maybe_transaction() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [old_value] = row + + await asyncio.sleep(0) + + new_value = old_value - 1 + await connection.execute("UPDATE counter SET value = :value", {"value": 
new_value}) + + +async def sum_counter(db_wrapper: DBWrapper2, output: list[int]) -> None: + async with db_wrapper.reader_no_transaction() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [value] = row + + output.append(value) + + +async def setup_table(db: DBWrapper2) -> None: + async with db.writer_maybe_transaction() as conn: + await conn.execute("CREATE TABLE counter(value INTEGER NOT NULL)") + await conn.execute("INSERT INTO counter(value) VALUES(0)") + + +async def get_value(cursor: aiosqlite.Cursor) -> int: + row = await cursor.fetchone() + assert row + return int(row[0]) + + +async def query_value(connection: SqliteConnection) -> int: + async with connection.execute("SELECT value FROM counter") as cursor: + return await get_value(cursor=cursor) + + +def _get_reader_no_transaction_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]: + return db_wrapper.reader_no_transaction + + +def _get_regular_reader_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]: + return db_wrapper.reader + + +@pytest.fixture( + name="get_reader_method", + params=[ + pytest.param(_get_reader_no_transaction_method, id="reader_no_transaction"), + pytest.param(_get_regular_reader_method, id="reader"), + ], +) +def get_reader_method_fixture(request: SubRequest) -> Callable[[], ConnectionContextManager]: + # https://github.com/pytest-dev/pytest/issues/8763 + return request.param # type: ignore[no-any-return] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + argnames="acquire_outside", + argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")], +) +async def test_concurrent_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + concurrent_task_count = 200 + + async with contextlib.AsyncExitStack() as exit_stack: + if acquire_outside: + await exit_stack.enter_async_context(db_wrapper.writer_maybe_transaction()) + + tasks = [] + for index in range(concurrent_task_count): + task = create_referenced_task(increment_counter(db_wrapper)) + tasks.append(task) + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=None) + + async with get_reader_method(db_wrapper)() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [value] = row + + assert value == concurrent_task_count + + +@pytest.mark.anyio +async def test_writers_nests() -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + async with db_wrapper.writer_maybe_transaction() as conn1: + async with conn1.execute("SELECT value FROM counter") as cursor: + value = await get_value(cursor) + async with db_wrapper.writer_maybe_transaction() as conn2: + assert conn1 == conn2 + value += 1 + await conn2.execute("UPDATE counter SET value = :value", {"value": value}) + async with db_wrapper.writer_maybe_transaction() as conn3: + assert conn1 == conn3 + async with conn3.execute("SELECT value FROM counter") as cursor: + value = await get_value(cursor) + + assert value == 1 + + +@pytest.mark.anyio +async def test_writer_journal_mode_wal() -> None: + async with PathDBConnection(2) as db_wrapper: + async with db_wrapper.writer() as connection: + async with connection.execute("PRAGMA journal_mode") as cursor: + result = await cursor.fetchone() + assert result == ("wal",) 
+ + +@pytest.mark.anyio +async def test_reader_journal_mode_wal() -> None: + async with PathDBConnection(2) as db_wrapper: + async with db_wrapper.reader_no_transaction() as connection: + async with connection.execute("PRAGMA journal_mode") as cursor: + result = await cursor.fetchone() + assert result == ("wal",) + + +@pytest.mark.anyio +async def test_partial_failure() -> None: + values = [] + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + async with db_wrapper.writer() as conn1: + await conn1.execute("UPDATE counter SET value = 42") + async with conn1.execute("SELECT value FROM counter") as cursor: + values.append(await get_value(cursor)) + try: + async with db_wrapper.writer() as conn2: + await conn2.execute("UPDATE counter SET value = 1337") + async with conn1.execute("SELECT value FROM counter") as cursor: + values.append(await get_value(cursor)) + # this simulates a failure, which will cause a rollback of the + # write we just made, back to 42 + raise RuntimeError("failure within a sub-transaction") + except RuntimeError: + # we expect to get here + values.append(1) + async with conn1.execute("SELECT value FROM counter") as cursor: + values.append(await get_value(cursor)) + + # the write of 1337 failed, and was restored to 42 + assert values == [42, 1337, 1, 42] + + +@pytest.mark.anyio +async def test_readers_nests(get_reader_method: GetReaderMethod) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with get_reader_method(db_wrapper)() as conn1: + async with get_reader_method(db_wrapper)() as conn2: + assert conn1 == conn2 + async with get_reader_method(db_wrapper)() as conn3: + assert conn1 == conn3 + async with conn3.execute("SELECT value FROM counter") as cursor: + value = await get_value(cursor) + + assert value == 0 + + +@pytest.mark.anyio +async def test_readers_nests_writer(get_reader_method: GetReaderMethod) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.writer_maybe_transaction() as conn1: + async with get_reader_method(db_wrapper)() as conn2: + assert conn1 == conn2 + async with db_wrapper.writer_maybe_transaction() as conn3: + assert conn1 == conn3 + async with conn3.execute("SELECT value FROM counter") as cursor: + value = await get_value(cursor) + + assert value == 0 + + +@pytest.mark.parametrize( + argnames="transactioned", + argvalues=[ + pytest.param(True, id="transaction"), + pytest.param(False, id="no transaction"), + ], +) +@pytest.mark.anyio +async def test_only_transactioned_reader_ignores_writer(transactioned: bool) -> None: + writer_committed = asyncio.Event() + reader_read = asyncio.Event() + + async def write() -> None: + try: + async with db_wrapper.writer() as writer: + assert reader is not writer + + await writer.execute("UPDATE counter SET value = 1") + finally: + writer_committed.set() + + await reader_read.wait() + + assert await query_value(connection=writer) == 1 + + async with PathDBConnection(2) as db_wrapper: + get_reader = db_wrapper.reader if transactioned else db_wrapper.reader_no_transaction + + await setup_table(db_wrapper) + + async with get_reader() as reader: + assert await query_value(connection=reader) == 0 + + task = create_referenced_task(write()) + await writer_committed.wait() + + assert await query_value(connection=reader) == 0 if transactioned else 1 + reader_read.set() + + await task + + async with get_reader() as reader: + assert await query_value(connection=reader) == 1 + + +@pytest.mark.anyio 
+async def test_reader_nests_and_ends_transaction() -> None: + async with DBConnection(2) as db_wrapper: + async with db_wrapper.reader() as reader: + assert reader.in_transaction + + async with db_wrapper.reader() as inner_reader: + assert inner_reader is reader + assert reader.in_transaction + + assert reader.in_transaction + + assert not reader.in_transaction + + +@pytest.mark.anyio +async def test_writer_in_reader_works() -> None: + async with PathDBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.reader() as reader: + async with db_wrapper.writer() as writer: + assert writer is not reader + await writer.execute("UPDATE counter SET value = 1") + assert await query_value(connection=writer) == 1 + assert await query_value(connection=reader) == 0 + + assert await query_value(connection=reader) == 0 + + +@pytest.mark.anyio +async def test_reader_transaction_is_deferred() -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.reader() as reader: + async with db_wrapper.writer() as writer: + assert writer is not reader + await writer.execute("UPDATE counter SET value = 1") + assert await query_value(connection=writer) == 1 + + # The deferred transaction initiation results in the transaction starting + # here and thus reading the written value. + assert await query_value(connection=reader) == 1 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + argnames="acquire_outside", + argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")], +) +async def test_concurrent_readers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.writer_maybe_transaction() as connection: + await connection.execute("UPDATE counter SET value = 1") + + concurrent_task_count = 200 + + async with contextlib.AsyncExitStack() as exit_stack: + if acquire_outside: + await exit_stack.enter_async_context(get_reader_method(db_wrapper)()) + + tasks = [] + values: list[int] = [] + for index in range(concurrent_task_count): + task = create_referenced_task(sum_counter(db_wrapper, values)) + tasks.append(task) + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=None) + + assert values == [1] * concurrent_task_count + + +@pytest.mark.anyio +@pytest.mark.parametrize( + argnames="acquire_outside", + argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")], +) +async def test_mixed_readers_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None: + async with PathDBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.writer_maybe_transaction() as connection: + await connection.execute("UPDATE counter SET value = 1") + + concurrent_task_count = 200 + + async with contextlib.AsyncExitStack() as exit_stack: + if acquire_outside: + await exit_stack.enter_async_context(get_reader_method(db_wrapper)()) + + tasks = [] + values: list[int] = [] + for index in range(concurrent_task_count): + task = create_referenced_task(increment_counter(db_wrapper)) + tasks.append(task) + task = create_referenced_task(decrement_counter(db_wrapper)) + tasks.append(task) + task = create_referenced_task(sum_counter(db_wrapper, values)) + tasks.append(task) + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=None) + + # we increment and decrement the counter an equal number of times. 
It should + # end back at 1. + async with get_reader_method(db_wrapper)() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + assert row is not None + assert row[0] == 1 + + # it's possible all increments or all decrements are run first + assert len(values) == concurrent_task_count + for v in values: + assert v > -99 + assert v <= 100 + + +@pytest.mark.parametrize( + argnames=["manager_method", "expected"], + argvalues=[ + [DBWrapper2.writer, True], + [DBWrapper2.writer_maybe_transaction, True], + [DBWrapper2.reader, True], + [DBWrapper2.reader_no_transaction, False], + ], +) +@pytest.mark.anyio +async def test_in_transaction_as_expected( + manager_method: Callable[[DBWrapper2], ConnectionContextManager], + expected: bool, +) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with manager_method(db_wrapper) as connection: + assert connection.in_transaction == expected + + +@pytest.mark.anyio +async def test_cancelled_reader_does_not_cancel_writer() -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.writer() as writer: + await writer.execute("UPDATE counter SET value = 1") + + with pytest.raises(UniqueError): + async with db_wrapper.reader() as _: + raise UniqueError + + assert await query_value(connection=writer) == 1 + + assert await query_value(connection=writer) == 1 + + +@boolean_datacases(name="initial", false="initially disabled", true="initially enabled") +@boolean_datacases(name="forced", false="forced disabled", true="forced enabled") +@pytest.mark.anyio +async def test_foreign_key_pragma_controlled_by_writer(initial: bool, forced: bool) -> None: + async with DBConnection(2, foreign_keys=initial) as db_wrapper: + async with db_wrapper.writer(foreign_key_enforcement_enabled=forced) as writer: + async with writer.execute("PRAGMA foreign_keys") as cursor: + result = await cursor.fetchone() + assert result is not None + [actual] = result + + assert actual == (1 if forced else 0) + + +@pytest.mark.anyio +async def test_foreign_key_pragma_rolls_back_on_foreign_key_error() -> None: + async with DBConnection(2, foreign_keys=True, row_factory=aiosqlite.Row) as db_wrapper: + async with db_wrapper.writer() as writer: + async with writer.execute( + """ + CREATE TABLE people( + id INTEGER NOT NULL, + friend INTEGER, + PRIMARY KEY (id), + FOREIGN KEY (friend) REFERENCES people + ) + """ + ): + pass + + async with writer.execute( + "INSERT INTO people(id, friend) VALUES (:id, :friend)", + {"id": 1, "friend": None}, + ): + pass + + async with writer.execute( + "INSERT INTO people(id, friend) VALUES (:id, :friend)", + {"id": 2, "friend": 1}, + ): + pass + + # make sure the writer raises a foreign key error on exit + with pytest.raises(ForeignKeyError): + async with db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: + async with writer.execute("DELETE FROM people WHERE id = 1"): + pass + + # make sure a foreign key error can be detected here + with pytest.raises(ForeignKeyError): + await db_wrapper._check_foreign_keys() + + async with writer.execute("SELECT * FROM people WHERE id = 1") as cursor: + [person] = await cursor.fetchall() + + # make sure the delete was rolled back + assert dict(person) == {"id": 1, "friend": None} + + +@dataclass +class RowFactoryCase: + id: str + factory: Optional[type[aiosqlite.Row]] + marks: Marks = () + + +row_factory_cases: list[RowFactoryCase] = [ + RowFactoryCase(id="default 
named tuple", factory=None), + RowFactoryCase(id="aiosqlite row", factory=aiosqlite.Row), +] + + +@datacases(*row_factory_cases) +@pytest.mark.anyio +async def test_foreign_key_check_failure_error_message(case: RowFactoryCase) -> None: + async with DBConnection(2, foreign_keys=True, row_factory=case.factory) as db_wrapper: + async with db_wrapper.writer() as writer: + async with writer.execute( + """ + CREATE TABLE people( + id INTEGER NOT NULL, + friend INTEGER, + PRIMARY KEY (id), + FOREIGN KEY (friend) REFERENCES people + ) + """ + ): + pass + + async with writer.execute( + "INSERT INTO people(id, friend) VALUES (:id, :friend)", + {"id": 1, "friend": None}, + ): + pass + + async with writer.execute( + "INSERT INTO people(id, friend) VALUES (:id, :friend)", + {"id": 2, "friend": 1}, + ): + pass + + # make sure the writer raises a foreign key error on exit + with pytest.raises(ForeignKeyError) as error: + async with db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: + async with writer.execute("DELETE FROM people WHERE id = 1"): + pass + + assert error.value.violations == [{"table": "people", "rowid": 2, "parent": "people", "fkid": 0}] + + +@pytest.mark.anyio +async def test_set_foreign_keys_fails_within_acquired_writer() -> None: + async with DBConnection(2, foreign_keys=True) as db_wrapper: + async with db_wrapper.writer(): + with pytest.raises( + InternalError, + match="Unable to set foreign key enforcement state while a writer is held", + ): + async with db_wrapper._set_foreign_key_enforcement(enabled=False): + pass # pragma: no cover + + +@boolean_datacases(name="initial", false="initially disabled", true="initially enabled") +@pytest.mark.anyio +async def test_delayed_foreign_key_request_fails_when_nested(initial: bool) -> None: + async with DBConnection(2, foreign_keys=initial) as db_wrapper: + async with db_wrapper.writer(): + with pytest.raises(NestedForeignKeyDelayedRequestError): + async with db_wrapper.writer(foreign_key_enforcement_enabled=True): + pass # pragma: no cover diff --git a/chia/util/sqlite_wrapper.py b/chia/util/sqlite_wrapper.py new file mode 100644 index 000000000000..5e4014ea25bd --- /dev/null +++ b/chia/util/sqlite_wrapper.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import contextlib +import functools +from collections.abc import AsyncIterator +from dataclasses import dataclass +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Self, TextIO, Union, cast + +import aiosqlite +import aiosqlite.context +import anyio +from aiosqlite import Cursor + +from chia.util.transactioner import ( + ConnectionProtocol, + CreateConnectionCallable, + Transactioner, + manage_connection, + sql_trace_callback, + # CursorProtocol, AwaitableEnterable, +) + +# if TYPE_CHECKING: +# _protocol_check: AwaitableEnterable[Cursor] = cast(aiosqlite.Cursor, None) + + +@dataclass +class SqliteConnection: + if TYPE_CHECKING: + _protocol_check: ClassVar[ConnectionProtocol[aiosqlite.context.Result[aiosqlite.Cursor]]] = cast( + "SqliteConnection", None + ) + + _connection: aiosqlite.Connection + + async def close(self) -> None: + return await self._connection.close() + + def execute(self, *args: Any, **kwargs: Any) -> aiosqlite.context.Result[Cursor]: + return self._connection.execute(*args, **kwargs) + + @property + def in_transaction(self) -> bool: + return self._connection.in_transaction + + async def rollback(self) -> None: + await self._connection.rollback() + + async def 
configure_as_reader(self) -> None: + await self.execute("pragma query_only") + + async def read_transaction(self) -> None: + await self.execute("BEGIN DEFERRED;") + + async def __aenter__(self) -> Self: + await self._connection.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + @contextlib.asynccontextmanager + async def savepoint_ctx(self, name: str) -> AsyncIterator[None]: + await self._connection.execute(f"SAVEPOINT {name}") + try: + yield + except: + await self._connection.execute(f"ROLLBACK TO {name}") + raise + finally: + # rollback to a savepoint doesn't cancel the transaction, it + # just rolls back the state. We need to cancel it regardless + await self._connection.execute(f"RELEASE {name}") + + +async def sqlite_create_connection( + database: Union[str, Path], + uri: bool = False, + log_file: Optional[TextIO] = None, + name: Optional[str] = None, + row_factory: Optional[type[aiosqlite.Row]] = None, +) -> SqliteConnection: + # To avoid https://github.com/python/cpython/issues/118172 + connection = await aiosqlite.connect(database=database, uri=uri, cached_statements=0) + + if log_file is not None: + await connection.set_trace_callback(functools.partial(sql_trace_callback, file=log_file, name=name)) + + if row_factory is not None: + connection.row_factory = row_factory + + return SqliteConnection(_connection=connection) + + +@contextlib.asynccontextmanager +async def managed( + database: Union[str, Path], + *, + db_version: int = 1, + uri: bool = False, + reader_count: int = 4, + log_path: Optional[Path] = None, + journal_mode: str = "WAL", + synchronous: Optional[str] = None, + foreign_keys: Optional[bool] = None, + row_factory: Optional[type[aiosqlite.Row]] = None, +) -> AsyncIterator[Transactioner[SqliteConnection]]: + if foreign_keys is None: + foreign_keys = False + + async with contextlib.AsyncExitStack() as async_exit_stack: + if log_path is None: + log_file = None + else: + log_path.parent.mkdir(parents=True, exist_ok=True) + log_file = async_exit_stack.enter_context(log_path.open("a", encoding="utf-8")) + + write_connection = await async_exit_stack.enter_async_context( + manage_connection( + create_connection=sqlite_create_connection, database=database, uri=uri, log_file=log_file, name="writer" + ), + ) + await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close() + if synchronous is not None: + await (await write_connection.execute(f"pragma synchronous={synchronous}")).close() + + await (await write_connection.execute(f"pragma foreign_keys={'ON' if foreign_keys else 'OFF'}")).close() + + write_connection._connection.row_factory = row_factory + + self = Transactioner( + create_connection=sqlite_create_connection, + _write_connection=write_connection, + db_version=db_version, + _log_file=log_file, + ) + + for index in range(reader_count): + read_connection = await async_exit_stack.enter_async_context( + manage_connection( + create_connection=sqlite_create_connection, + database=database, + uri=uri, + log_file=log_file, + name=f"reader-{index}", + ), + ) + read_connection._connection.row_factory = row_factory + + await self.add_connection(c=read_connection) + try: + yield self + finally: + with anyio.CancelScope(shield=True): + while self._num_read_connections > 0: + await self._read_connections.get() + self._num_read_connections -= 1 + + 
+async def create( + create_connection: CreateConnectionCallable[SqliteConnection], + database: Union[str, Path], + *, + db_version: int = 1, + uri: bool = False, + reader_count: int = 4, + log_path: Optional[Path] = None, + journal_mode: str = "WAL", + synchronous: Optional[str] = None, + foreign_keys: bool = False, +) -> Transactioner[SqliteConnection]: + # WARNING: please use .managed() instead + if log_path is None: + log_file = None + else: + log_path.parent.mkdir(parents=True, exist_ok=True) + log_file = log_path.open("a", encoding="utf-8") + write_connection = await create_connection(database=database, uri=uri, log_file=log_file, name="writer") + await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close() + if synchronous is not None: + await (await write_connection.execute(f"pragma synchronous={synchronous}")).close() + + await (await write_connection.execute(f"pragma foreign_keys={'ON' if foreign_keys else 'OFF'}")).close() + + self = Transactioner( + create_connection=create_connection, + _write_connection=write_connection, + db_version=db_version, + _log_file=log_file, + ) + + for index in range(reader_count): + read_connection = await create_connection( + database=database, + uri=uri, + log_file=log_file, + name=f"reader-{index}", + ) + await self.add_connection(c=read_connection) + + return self diff --git a/chia/util/transactioner.py b/chia/util/transactioner.py new file mode 100644 index 000000000000..9593b037291d --- /dev/null +++ b/chia/util/transactioner.py @@ -0,0 +1,376 @@ +# Package: utils + +from __future__ import annotations + +import asyncio +import contextlib +import secrets +import sqlite3 +import sys +from collections.abc import AsyncIterator, Iterable +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from types import TracebackType +from typing import Any, Generic, Optional, Protocol, TextIO, TypeVar, Union + +import aiosqlite +import anyio +from typing_extensions import Self, final + +if aiosqlite.sqlite_version_info < (3, 32, 0): + SQLITE_MAX_VARIABLE_NUMBER = 900 +else: + SQLITE_MAX_VARIABLE_NUMBER = 32700 + +# integers in sqlite are limited by int64 +SQLITE_INT_MAX = 2**63 - 1 + + +class DBWrapperError(Exception): + pass + + +class ForeignKeyError(DBWrapperError): + def __init__(self, violations: Iterable[Union[aiosqlite.Row, tuple[str, object, str, object]]]) -> None: + self.violations: list[dict[str, object]] = [] + + for violation in violations: + if isinstance(violation, tuple): + violation_dict = dict(zip(["table", "rowid", "parent", "fkid"], violation)) + else: + violation_dict = dict(violation) + self.violations.append(violation_dict) + + super().__init__(f"Found {len(self.violations)} FK violations: {self.violations}") + + +class NestedForeignKeyDelayedRequestError(DBWrapperError): + def __init__(self) -> None: + super().__init__("Unable to enable delayed foreign key enforcement in a nested request.") + + +class InternalError(DBWrapperError): + pass + + +class PurposefulAbort(DBWrapperError): + obj: object + + def __init__(self, obj: object) -> None: + self.obj = obj + + +def generate_in_memory_db_uri() -> str: + # We need to use shared cache as our DB wrapper uses different types of connections + return f"file:db_{secrets.token_hex(16)}?mode=memory&cache=shared" + + +async def execute_fetchone( + c: aiosqlite.Connection, sql: str, parameters: Optional[Iterable[Any]] = None +) -> Optional[sqlite3.Row]: + rows = await c.execute_fetchall(sql, parameters) + for row in rows: + 
return row + return None + + +def sql_trace_callback(req: str, file: TextIO, name: Optional[str] = None) -> None: + timestamp = datetime.now().strftime("%H:%M:%S.%f") + if name is not None: + line = f"{timestamp} {name} {req}\n" + else: + line = f"{timestamp} {req}\n" + file.write(line) + + +def get_host_parameter_limit() -> int: + # NOTE: This does not account for dynamically adjusted limits since it makes a + # separate db and connection. If aiosqlite adds support we should use it. + if sys.version_info >= (3, 11): + connection = sqlite3.connect(":memory:") + + limit_number = sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER + host_parameter_limit = connection.getlimit(limit_number) + else: + # guessing based on defaults, seems you can't query + + # https://www.sqlite.org/changes.html#version_3_32_0 + # Increase the default upper bound on the number of parameters from 999 to 32766. + if sqlite3.sqlite_version_info >= (3, 32, 0): + host_parameter_limit = 32766 + else: + host_parameter_limit = 999 + return host_parameter_limit + + +# class CursorProtocol(Protocol): +# async def close(self) -> None: ... +# # async def __aenter__(self) -> Self: ... +# # async def __aexit__( +# # self, +# # exc_type: Optional[type[BaseException]], +# # exc_val: Optional[BaseException], +# # exc_tb: Optional[TracebackType], +# # ) -> None: ... +# +# +# TCursor_co = TypeVar("TCursor_co", bound=CursorProtocol, covariant=True) +# +# class AwaitableEnterable(Protocol[TCursor_co]): +# def __await__(self) -> TCursor_co: ... +# async def __aenter__(self) -> TCursor_co: ... +# async def __aexit__( +# self, +# exc_type: Optional[type[BaseException]], +# exc_val: Optional[BaseException], +# exc_tb: Optional[TracebackType], +# ) -> None: ... + + +T_co = TypeVar("T_co", covariant=True) + + +class ConnectionProtocol(Protocol[T_co]): + # TODO: this is presently matching aiosqlite.Connection, generalize + + async def close(self) -> None: ... + def execute(self, *args: Any, **kwargs: Any) -> T_co: ... + @property + def in_transaction(self) -> bool: ... + async def rollback(self) -> None: ... + async def configure_as_reader(self) -> None: ... + async def read_transaction(self) -> None: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... + + @contextlib.asynccontextmanager + async def savepoint_ctx(self, name: str) -> AsyncIterator[None]: + yield + + +# a_cursor: CursorProtocol = cast(aiosqlite.Cursor, None) + +# TODO: are these ok missing type parameters so they stay generic? or... +TConnection = TypeVar("TConnection", bound=ConnectionProtocol) # type: ignore[type-arg] +TConnection_co = TypeVar("TConnection_co", bound=ConnectionProtocol, covariant=True) # type: ignore[type-arg] + + +class CreateConnectionCallable(Protocol[TConnection_co]): + async def __call__( + self, + database: Union[str, Path], + uri: bool = False, + log_file: Optional[TextIO] = None, + name: Optional[str] = None, + ) -> TConnection_co: ... 
+ + +@contextlib.asynccontextmanager +async def manage_connection( + create_connection: CreateConnectionCallable[TConnection_co], + database: Union[str, Path], + uri: bool = False, + log_file: Optional[TextIO] = None, + name: Optional[str] = None, +) -> AsyncIterator[TConnection_co]: + connection: TConnection_co + connection = await create_connection(database=database, uri=uri, log_file=log_file, name=name) + + try: + yield connection + finally: + with anyio.CancelScope(shield=True): + await connection.close() + + +@final +@dataclass +class Transactioner(Generic[TConnection]): + create_connection: CreateConnectionCallable[TConnection] + _write_connection: TConnection + db_version: int = 1 + _log_file: Optional[TextIO] = None + host_parameter_limit: int = get_host_parameter_limit() + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) + _read_connections: asyncio.Queue[TConnection] = field(default_factory=asyncio.Queue) + _num_read_connections: int = 0 + _in_use: dict[asyncio.Task[object], TConnection] = field(default_factory=dict) + _current_writer: Optional[asyncio.Task[object]] = None + _savepoint_name: int = 0 + + async def add_connection(self, c: TConnection) -> None: + # this guarantees that reader connections can only be used for reading + assert c != self._write_connection + await c.configure_as_reader() + self._read_connections.put_nowait(c) + self._num_read_connections += 1 + + async def close(self) -> None: + # WARNING: please use .managed() instead + try: + while self._num_read_connections > 0: + await (await self._read_connections.get()).close() + self._num_read_connections -= 1 + await self._write_connection.close() + finally: + if self._log_file is not None: + self._log_file.close() + + def _next_savepoint(self) -> str: + name = f"s{self._savepoint_name}" + self._savepoint_name += 1 + return name + + @contextlib.asynccontextmanager + async def writer( + self, + foreign_key_enforcement_enabled: Optional[bool] = None, + ) -> AsyncIterator[TConnection]: + """ + Initiates a new, possibly nested, transaction. If this task is already + in a transaction, none of the changes made as part of this transaction + will become visible to others until that top level transaction commits. + If this transaction fails (by exiting the context manager with an + exception) this transaction will be rolled back, but the next outer + transaction is not necessarily cancelled. It would also need to exit + with an exception to be cancelled. + The sqlite features this relies on are SAVEPOINT, ROLLBACK TO and RELEASE. + """ + task = asyncio.current_task() + assert task is not None + if self._current_writer == task: + # we allow nesting writers within the same task + if foreign_key_enforcement_enabled is not None: + # NOTE: Technically this is complaining even if the requested state is + # already in place. This could be adjusted to allow nesting + # when the existing and requested states agree. In this case, + # probably skip the nested foreign key check when exiting since + # we don't have many foreign key errors and so it is likely ok + # to save the extra time checking twice. 
+ raise NestedForeignKeyDelayedRequestError + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + yield self._write_connection + return + + async with self._lock: + async with contextlib.AsyncExitStack() as exit_stack: + if foreign_key_enforcement_enabled is not None: + await exit_stack.enter_async_context( + self._set_foreign_key_enforcement(enabled=foreign_key_enforcement_enabled), + ) + + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + self._current_writer = task + try: + yield self._write_connection + + if foreign_key_enforcement_enabled is not None and not foreign_key_enforcement_enabled: + await self._check_foreign_keys() + finally: + self._current_writer = None + + @contextlib.asynccontextmanager + async def _set_foreign_key_enforcement(self, enabled: bool) -> AsyncIterator[None]: + if self._current_writer is not None: + raise InternalError("Unable to set foreign key enforcement state while a writer is held") + + async with self._write_connection.execute("PRAGMA foreign_keys") as cursor: + result = await cursor.fetchone() + if result is None: # pragma: no cover + raise InternalError("No results when querying for present foreign key enforcement state") + [original_value] = result + + if original_value == enabled: + yield + return + + try: + await self._write_connection.execute(f"PRAGMA foreign_keys={enabled}") + yield + finally: + with anyio.CancelScope(shield=True): + await self._write_connection.execute(f"PRAGMA foreign_keys={original_value}") + + async def _check_foreign_keys(self) -> None: + async with self._write_connection.execute("PRAGMA foreign_key_check") as cursor: + violations = list(await cursor.fetchall()) + + if len(violations) > 0: + raise ForeignKeyError(violations=violations) + + @contextlib.asynccontextmanager + async def writer_maybe_transaction(self) -> AsyncIterator[TConnection]: + """ + Initiates a write to the database. If this task is already in a write + transaction with the DB, this is a no-op. Any changes made to the + database will be rolled up into the transaction we're already in. If the + current task is not already in a transaction, one will be created and + committed (or rolled back in the case of an exception). + """ + task = asyncio.current_task() + assert task is not None + if self._current_writer == task: + # just use the existing transaction + yield self._write_connection + return + + async with self._lock: + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + self._current_writer = task + try: + yield self._write_connection + finally: + self._current_writer = None + + @contextlib.asynccontextmanager + async def reader(self) -> AsyncIterator[TConnection]: + async with self.reader_no_transaction() as connection: + if connection.in_transaction: + yield connection + else: + await connection.read_transaction() + try: + yield connection + finally: + # close the transaction with a rollback instead of commit just in + # case any modifications were submitted through this reader + await connection.rollback() + + @contextlib.asynccontextmanager + async def reader_no_transaction(self) -> AsyncIterator[TConnection]: + # there should have been read connections added + assert self._num_read_connections > 0 + + # we can have multiple concurrent readers, just pick a connection from + # the pool of readers. If they're all busy, we'll wait for one to free + # up. 
+ task = asyncio.current_task() + assert task is not None + + # if this task currently holds the write lock, use the same connection, + # so it can read back updates it has made to its transaction, even + # though it hasn't been committed yet + if self._current_writer == task: + # we allow nesting reading while also having a writer connection + # open, within the same task + yield self._write_connection + return + + if task in self._in_use: + yield self._in_use[task] + else: + c = await self._read_connections.get() + try: + # record our connection in this dict to allow nested calls in + # the same task to use the same connection + self._in_use[task] = c + yield c + finally: + del self._in_use[task] + self._read_connections.put_nowait(c) From d433d2cdd6f0f229f0e490e7dd120fe862a92fd3 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 21 Jul 2025 19:23:27 -0400 Subject: [PATCH 2/6] switch over datalayer --- chia/data_layer/data_layer_util.py | 5 +++-- chia/data_layer/data_store.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py index 517b763517a0..e6cabe200169 100644 --- a/chia/data_layer/data_layer_util.py +++ b/chia/data_layer/data_layer_util.py @@ -14,9 +14,10 @@ from chia.data_layer.data_layer_errors import ProofIntegrityError from chia.server.ws_connection import WSChiaConnection from chia.types.blockchain_format.program import Program +from chia.util import sqlite_wrapper from chia.util.byte_types import hexstr_to_bytes -from chia.util.db_wrapper import DBWrapper2 from chia.util.streamable import Streamable, streamable +from chia.util.transactioner import Transactioner from chia.wallet.db_wallet.db_wallet_puzzles import create_host_fullpuz if TYPE_CHECKING: @@ -77,7 +78,7 @@ def get_hashes_for_page(page: int, lengths: dict[bytes32, int], max_page_size: i return PaginationData(current_page + 1, total_bytes, hashes) -async def _debug_dump(db: DBWrapper2, description: str = "") -> None: +async def _debug_dump(db: Transactioner[sqlite_wrapper.SqliteConnection], description: str = "") -> None: async with db.reader() as reader: cursor = await reader.execute("SELECT name FROM sqlite_master WHERE type='table';") print("-" * 50, description, flush=True) diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 039b68973693..45216190bd18 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -42,7 +42,10 @@ unspecified, ) from chia.types.blockchain_format.program import Program -from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2 +from chia.util import sqlite_wrapper +from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER +from chia.util.sqlite_wrapper import SqliteConnection +from chia.util.transactioner import Transactioner log = logging.getLogger(__name__) @@ -55,14 +58,14 @@ class DataStore: """A key/value store with the pairs being terminal nodes in a CLVM object tree.""" - db_wrapper: DBWrapper2 + db_wrapper: Transactioner[sqlite_wrapper.SqliteConnection] @classmethod @contextlib.asynccontextmanager async def managed( cls, database: Union[str, Path], uri: bool = False, sql_log_path: Optional[Path] = None ) -> AsyncIterator[DataStore]: - async with DBWrapper2.managed( + async with sqlite_wrapper.managed( database=database, uri=uri, journal_mode="WAL", @@ -756,7 +759,7 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3 async def get_keys_values_cursor( self, - reader: 
aiosqlite.Connection, + reader: SqliteConnection, root_hash: Optional[bytes32], only_keys: bool = False, ) -> aiosqlite.Cursor: @@ -1404,7 +1407,7 @@ async def upsert( return InsertResult(node_hash=new_terminal_node_hash, root=new_root) - async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None) -> None: + async def clean_node_table(self, writer: Optional[SqliteConnection] = None) -> None: query = """ WITH RECURSIVE pending_nodes AS ( SELECT node_hash AS hash FROM root From 043eeeffcc5075f49f72588bdc08527a5ed29a10 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 22 Jul 2025 09:44:47 -0400 Subject: [PATCH 3/6] fixup --- chia/util/sqlite_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chia/util/sqlite_wrapper.py b/chia/util/sqlite_wrapper.py index 5e4014ea25bd..8a058a198252 100644 --- a/chia/util/sqlite_wrapper.py +++ b/chia/util/sqlite_wrapper.py @@ -6,12 +6,13 @@ from dataclasses import dataclass from pathlib import Path from types import TracebackType -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Self, TextIO, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, TextIO, Union, cast import aiosqlite import aiosqlite.context import anyio from aiosqlite import Cursor +from typing_extensions import Self from chia.util.transactioner import ( ConnectionProtocol, From 2798b3edb37cd7822db44467601bc76daa624fa9 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 29 Jul 2025 15:02:32 -0400 Subject: [PATCH 4/6] extract delayed foreign key for sqlite --- chia/_tests/db/test_sqlite_wrapper.py | 56 ++++---- chia/data_layer/data_layer_util.py | 3 +- chia/data_layer/data_store.py | 189 +++++++++++++------------- chia/util/sqlite_wrapper.py | 89 +++++++++++- chia/util/transactioner.py | 114 ++++++---------- 5 files changed, 253 insertions(+), 198 deletions(-) diff --git a/chia/_tests/db/test_sqlite_wrapper.py b/chia/_tests/db/test_sqlite_wrapper.py index 5fb0fe093b91..8fd6e2422acf 100644 --- a/chia/_tests/db/test_sqlite_wrapper.py +++ b/chia/_tests/db/test_sqlite_wrapper.py @@ -17,17 +17,19 @@ from chia._tests.util.misc import Marks, boolean_datacases, datacases from chia.util import sqlite_wrapper -from chia.util.sqlite_wrapper import SqliteConnection +from chia.util.sqlite_wrapper import ( + ForeignKeyError, + NestedForeignKeyDelayedRequestError, + SqliteConnection, + SqliteTransactioner, +) from chia.util.task_referencer import create_referenced_task from chia.util.transactioner import ( - ForeignKeyError, InternalError, - NestedForeignKeyDelayedRequestError, - Transactioner, generate_in_memory_db_uri, ) -DBWrapper2 = Transactioner[SqliteConnection] +DBWrapper2 = SqliteTransactioner @asynccontextmanager @@ -477,11 +479,13 @@ async def test_cancelled_reader_does_not_cancel_writer() -> None: @pytest.mark.anyio async def test_foreign_key_pragma_controlled_by_writer(initial: bool, forced: bool) -> None: async with DBConnection(2, foreign_keys=initial) as db_wrapper: - async with db_wrapper.writer(foreign_key_enforcement_enabled=forced) as writer: - async with writer.execute("PRAGMA foreign_keys") as cursor: - result = await cursor.fetchone() - assert result is not None - [actual] = result + async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with db_wrapper.writer() as writer: + async with writer.execute("PRAGMA foreign_keys") as cursor: + result = await cursor.fetchone() + assert result is not None + 
[actual] = result assert actual == (1 if forced else 0) @@ -516,13 +520,15 @@ async def test_foreign_key_pragma_rolls_back_on_foreign_key_error() -> None: # make sure the writer raises a foreign key error on exit with pytest.raises(ForeignKeyError): - async with db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - async with writer.execute("DELETE FROM people WHERE id = 1"): - pass + async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with db_wrapper.writer() as writer: + async with writer.execute("DELETE FROM people WHERE id = 1"): + pass - # make sure a foreign key error can be detected here - with pytest.raises(ForeignKeyError): - await db_wrapper._check_foreign_keys() + # make sure a foreign key error can be detected here + with pytest.raises(ForeignKeyError): + await writer._check_foreign_keys() async with writer.execute("SELECT * FROM people WHERE id = 1") as cursor: [person] = await cursor.fetchall() @@ -575,9 +581,11 @@ async def test_foreign_key_check_failure_error_message(case: RowFactoryCase) -> # make sure the writer raises a foreign key error on exit with pytest.raises(ForeignKeyError) as error: - async with db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - async with writer.execute("DELETE FROM people WHERE id = 1"): - pass + async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with db_wrapper.writer() as writer: + async with writer.execute("DELETE FROM people WHERE id = 1"): + pass assert error.value.violations == [{"table": "people", "rowid": 2, "parent": "people", "fkid": 0}] @@ -585,12 +593,12 @@ async def test_foreign_key_check_failure_error_message(case: RowFactoryCase) -> @pytest.mark.anyio async def test_set_foreign_keys_fails_within_acquired_writer() -> None: async with DBConnection(2, foreign_keys=True) as db_wrapper: - async with db_wrapper.writer(): + async with db_wrapper.writer() as writer: with pytest.raises( InternalError, match="Unable to set foreign key enforcement state while a writer is held", ): - async with db_wrapper._set_foreign_key_enforcement(enabled=False): + async with writer._set_foreign_key_enforcement(enabled=False): pass # pragma: no cover @@ -600,5 +608,7 @@ async def test_delayed_foreign_key_request_fails_when_nested(initial: bool) -> N async with DBConnection(2, foreign_keys=initial) as db_wrapper: async with db_wrapper.writer(): with pytest.raises(NestedForeignKeyDelayedRequestError): - async with db_wrapper.writer(foreign_key_enforcement_enabled=True): - pass # pragma: no cover + async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with db_wrapper.writer(): + pass # pragma: no cover diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py index e6cabe200169..14ce9a55ad29 100644 --- a/chia/data_layer/data_layer_util.py +++ b/chia/data_layer/data_layer_util.py @@ -17,7 +17,6 @@ from chia.util import sqlite_wrapper from chia.util.byte_types import hexstr_to_bytes from chia.util.streamable import Streamable, streamable -from chia.util.transactioner import Transactioner from chia.wallet.db_wallet.db_wallet_puzzles import create_host_fullpuz if TYPE_CHECKING: @@ -78,7 +77,7 @@ def get_hashes_for_page(page: int, lengths: dict[bytes32, int], max_page_size: i return 
PaginationData(current_page + 1, total_bytes, hashes) -async def _debug_dump(db: Transactioner[sqlite_wrapper.SqliteConnection], description: str = "") -> None: +async def _debug_dump(db: sqlite_wrapper.SqliteTransactioner, description: str = "") -> None: async with db.reader() as reader: cursor = await reader.execute("SELECT name FROM sqlite_master WHERE type='table';") print("-" * 50, description, flush=True) diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 45216190bd18..aa196c276e02 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -45,7 +45,6 @@ from chia.util import sqlite_wrapper from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER from chia.util.sqlite_wrapper import SqliteConnection -from chia.util.transactioner import Transactioner log = logging.getLogger(__name__) @@ -58,7 +57,7 @@ class DataStore: """A key/value store with the pairs being terminal nodes in a CLVM object tree.""" - db_wrapper: Transactioner[sqlite_wrapper.SqliteConnection] + db_wrapper: sqlite_wrapper.SqliteTransactioner @classmethod @contextlib.asynccontextmanager @@ -200,25 +199,27 @@ async def migrate_db(self) -> None: version = "v1.0" log.info(f"Initiating migration to version {version}") - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await writer.execute( - f""" - CREATE TABLE IF NOT EXISTS new_root( - tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), - generation INTEGER NOT NULL CHECK(generation >= 0), - node_hash BLOB, - status INTEGER NOT NULL CHECK( - {" OR ".join(f"status == {status}" for status in Status)} - ), - PRIMARY KEY(tree_id, generation), - FOREIGN KEY(node_hash) REFERENCES node(hash) - ) - """ - ) - await writer.execute("INSERT INTO new_root SELECT * FROM root") - await writer.execute("DROP TABLE root") - await writer.execute("ALTER TABLE new_root RENAME TO root") - await writer.execute("INSERT INTO schema (version_id) VALUES (?)", (version,)) + async with self.db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with self.db_wrapper.writer() as writer: + await writer.execute( + f""" + CREATE TABLE IF NOT EXISTS new_root( + tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), + generation INTEGER NOT NULL CHECK(generation >= 0), + node_hash BLOB, + status INTEGER NOT NULL CHECK( + {" OR ".join(f"status == {status}" for status in Status)} + ), + PRIMARY KEY(tree_id, generation), + FOREIGN KEY(node_hash) REFERENCES node(hash) + ) + """ + ) + await writer.execute("INSERT INTO new_root SELECT * FROM root") + await writer.execute("DROP TABLE root") + await writer.execute("ALTER TABLE new_root RENAME TO root") + await writer.execute("INSERT INTO schema (version_id) VALUES (?)", (version,)) log.info(f"Finished migrating DB to version {version}") async def _insert_root( @@ -1431,8 +1432,10 @@ async def clean_node_table(self, writer: Optional[SqliteConnection] = None) -> N """ params = {"pending_status": Status.PENDING.value, "pending_batch_status": Status.PENDING_BATCH.value} if writer is None: - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await writer.execute(query, params) + async with self.db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with self.db_wrapper.writer() as writer: + await writer.execute(query, params) else: await writer.execute(query, 
params) @@ -2068,76 +2071,78 @@ async def remove_subscriptions(self, store_id: bytes32, urls: list[str]) -> None ) async def delete_store_data(self, store_id: bytes32) -> None: - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await self.clean_node_table(writer) - cursor = await writer.execute( - """ - WITH RECURSIVE all_nodes AS ( - SELECT a.hash, n.left, n.right - FROM ancestors AS a - JOIN node AS n ON a.hash = n.hash - WHERE a.tree_id = :tree_id - ), - pending_nodes AS ( - SELECT node_hash AS hash FROM root - WHERE status IN (:pending_status, :pending_batch_status) - UNION ALL - SELECT n.left FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.left IS NOT NULL - UNION ALL - SELECT n.right FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.right IS NOT NULL - ) + async with self.db_wrapper.writer_no_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with self.db_wrapper.writer() as writer: + await self.clean_node_table(writer) + cursor = await writer.execute( + """ + WITH RECURSIVE all_nodes AS ( + SELECT a.hash, n.left, n.right + FROM ancestors AS a + JOIN node AS n ON a.hash = n.hash + WHERE a.tree_id = :tree_id + ), + pending_nodes AS ( + SELECT node_hash AS hash FROM root + WHERE status IN (:pending_status, :pending_batch_status) + UNION ALL + SELECT n.left FROM node n + INNER JOIN pending_nodes pn ON n.hash = pn.hash + WHERE n.left IS NOT NULL + UNION ALL + SELECT n.right FROM node n + INNER JOIN pending_nodes pn ON n.hash = pn.hash + WHERE n.right IS NOT NULL + ) - SELECT hash, left, right - FROM all_nodes - WHERE hash NOT IN (SELECT hash FROM ancestors WHERE tree_id != :tree_id) - AND hash NOT IN (SELECT hash from pending_nodes) - """, - { - "tree_id": store_id, - "pending_status": Status.PENDING.value, - "pending_batch_status": Status.PENDING_BATCH.value, - }, - ) - to_delete: dict[bytes, tuple[bytes, bytes]] = {} - ref_counts: dict[bytes, int] = {} - async for row in cursor: - hash = row["hash"] - left = row["left"] - right = row["right"] - if hash in to_delete: - prev_left, prev_right = to_delete[hash] - assert prev_left == left - assert prev_right == right - continue - to_delete[hash] = (left, right) - if left is not None: - ref_counts[left] = ref_counts.get(left, 0) + 1 - if right is not None: - ref_counts[right] = ref_counts.get(right, 0) + 1 - - await writer.execute("DELETE FROM ancestors WHERE tree_id == ?", (store_id,)) - await writer.execute("DELETE FROM root WHERE tree_id == ?", (store_id,)) - queue = [hash for hash in to_delete if ref_counts.get(hash, 0) == 0] - while queue: - hash = queue.pop(0) - if hash not in to_delete: - continue - await writer.execute("DELETE FROM node WHERE hash == ?", (hash,)) - - left, right = to_delete[hash] - if left is not None: - ref_counts[left] -= 1 - if ref_counts[left] == 0: - queue.append(left) - - if right is not None: - ref_counts[right] -= 1 - if ref_counts[right] == 0: - queue.append(right) + SELECT hash, left, right + FROM all_nodes + WHERE hash NOT IN (SELECT hash FROM ancestors WHERE tree_id != :tree_id) + AND hash NOT IN (SELECT hash from pending_nodes) + """, + { + "tree_id": store_id, + "pending_status": Status.PENDING.value, + "pending_batch_status": Status.PENDING_BATCH.value, + }, + ) + to_delete: dict[bytes, tuple[bytes, bytes]] = {} + ref_counts: dict[bytes, int] = {} + async for row in cursor: + hash = row["hash"] + left = row["left"] + right = row["right"] + if 
hash in to_delete: + prev_left, prev_right = to_delete[hash] + assert prev_left == left + assert prev_right == right + continue + to_delete[hash] = (left, right) + if left is not None: + ref_counts[left] = ref_counts.get(left, 0) + 1 + if right is not None: + ref_counts[right] = ref_counts.get(right, 0) + 1 + + await writer.execute("DELETE FROM ancestors WHERE tree_id == ?", (store_id,)) + await writer.execute("DELETE FROM root WHERE tree_id == ?", (store_id,)) + queue = [hash for hash in to_delete if ref_counts.get(hash, 0) == 0] + while queue: + hash = queue.pop(0) + if hash not in to_delete: + continue + await writer.execute("DELETE FROM node WHERE hash == ?", (hash,)) + + left, right = to_delete[hash] + if left is not None: + ref_counts[left] -= 1 + if ref_counts[left] == 0: + queue.append(left) + + if right is not None: + ref_counts[right] -= 1 + if ref_counts[right] == 0: + queue.append(right) async def unsubscribe(self, store_id: bytes32) -> None: async with self.db_wrapper.writer() as writer: diff --git a/chia/util/sqlite_wrapper.py b/chia/util/sqlite_wrapper.py index 8a058a198252..85db6927b1d7 100644 --- a/chia/util/sqlite_wrapper.py +++ b/chia/util/sqlite_wrapper.py @@ -2,11 +2,11 @@ import contextlib import functools -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass from pathlib import Path from types import TracebackType -from typing import TYPE_CHECKING, Any, ClassVar, Optional, TextIO, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, TextIO, TypeAlias, Union, cast import aiosqlite import aiosqlite.context @@ -17,16 +17,64 @@ from chia.util.transactioner import ( ConnectionProtocol, CreateConnectionCallable, + InternalError, Transactioner, manage_connection, sql_trace_callback, - # CursorProtocol, AwaitableEnterable, ) +SqliteTransactioner: TypeAlias = Transactioner["SqliteConnection", "UntransactionedSqliteConnection"] + # if TYPE_CHECKING: # _protocol_check: AwaitableEnterable[Cursor] = cast(aiosqlite.Cursor, None) +# TODO: think about the inheritance +# class NestedForeignKeyDelayedRequestError(DBWrapperError): +class NestedForeignKeyDelayedRequestError(Exception): + def __init__(self) -> None: + super().__init__("Unable to enable delayed foreign key enforcement in a nested request.") + + +# TODO: think about the inheritance +# class ForeignKeyError(DBWrapperError): +class ForeignKeyError(Exception): + def __init__(self, violations: Iterable[Union[aiosqlite.Row, tuple[str, object, str, object]]]) -> None: + self.violations: list[dict[str, object]] = [] + + for violation in violations: + if isinstance(violation, tuple): + violation_dict = dict(zip(["table", "rowid", "parent", "fkid"], violation)) + else: + violation_dict = dict(violation) + self.violations.append(violation_dict) + + super().__init__(f"Found {len(self.violations)} FK violations: {self.violations}") + + +@dataclass +class UntransactionedSqliteConnection: + _connection: SqliteConnection + + @contextlib.asynccontextmanager + async def delay(self, *, foreign_key_enforcement_enabled: bool) -> AsyncIterator[None]: + if self._connection.in_transaction and foreign_key_enforcement_enabled is not None: + # NOTE: Technically this is complaining even if the requested state is + # already in place. This could be adjusted to allow nesting + # when the existing and requested states agree. 
In this case, + # probably skip the nested foreign key check when exiting since + # we don't have many foreign key errors and so it is likely ok + # to save the extra time checking twice. + raise NestedForeignKeyDelayedRequestError + + async with self._connection._set_foreign_key_enforcement(enabled=foreign_key_enforcement_enabled): + async with self._connection.savepoint_ctx("delay"): + try: + yield + finally: + await self._connection._check_foreign_keys() + + @dataclass class SqliteConnection: if TYPE_CHECKING: @@ -55,6 +103,35 @@ async def configure_as_reader(self) -> None: async def read_transaction(self) -> None: await self.execute("BEGIN DEFERRED;") + @contextlib.asynccontextmanager + async def _set_foreign_key_enforcement(self, enabled: bool) -> AsyncIterator[None]: + if self.in_transaction: + raise InternalError("Unable to set foreign key enforcement state while a writer is held") + + async with self._connection.execute("PRAGMA foreign_keys") as cursor: + result = await cursor.fetchone() + if result is None: # pragma: no cover + raise InternalError("No results when querying for present foreign key enforcement state") + [original_value] = result + + if original_value == enabled: + yield + return + + try: + await self._connection.execute(f"PRAGMA foreign_keys={enabled}") + yield + finally: + with anyio.CancelScope(shield=True): + await self._connection.execute(f"PRAGMA foreign_keys={original_value}") + + async def _check_foreign_keys(self) -> None: + async with self._connection.execute("PRAGMA foreign_key_check") as cursor: + violations = list(await cursor.fetchall()) + + if len(violations) > 0: + raise ForeignKeyError(violations=violations) + async def __aenter__(self) -> Self: await self._connection.__aenter__() return self @@ -112,7 +189,7 @@ async def managed( synchronous: Optional[str] = None, foreign_keys: Optional[bool] = None, row_factory: Optional[type[aiosqlite.Row]] = None, -) -> AsyncIterator[Transactioner[SqliteConnection]]: +) -> AsyncIterator[SqliteTransactioner]: if foreign_keys is None: foreign_keys = False @@ -138,6 +215,7 @@ async def managed( self = Transactioner( create_connection=sqlite_create_connection, + create_untransactioned_connection=UntransactionedSqliteConnection, _write_connection=write_connection, db_version=db_version, _log_file=log_file, @@ -176,7 +254,7 @@ async def create( journal_mode: str = "WAL", synchronous: Optional[str] = None, foreign_keys: bool = False, -) -> Transactioner[SqliteConnection]: +) -> SqliteTransactioner: # WARNING: please use .managed() instead if log_path is None: log_file = None @@ -192,6 +270,7 @@ async def create( self = Transactioner( create_connection=create_connection, + create_untransactioned_connection=UntransactionedSqliteConnection, _write_connection=write_connection, db_version=db_version, _log_file=log_file, diff --git a/chia/util/transactioner.py b/chia/util/transactioner.py index 9593b037291d..52c16d0194fe 100644 --- a/chia/util/transactioner.py +++ b/chia/util/transactioner.py @@ -12,7 +12,7 @@ from datetime import datetime from pathlib import Path from types import TracebackType -from typing import Any, Generic, Optional, Protocol, TextIO, TypeVar, Union +from typing import Any, Callable, Generic, Optional, Protocol, TextIO, TypeVar, Union import aiosqlite import anyio @@ -31,25 +31,6 @@ class DBWrapperError(Exception): pass -class ForeignKeyError(DBWrapperError): - def __init__(self, violations: Iterable[Union[aiosqlite.Row, tuple[str, object, str, object]]]) -> None: - self.violations: list[dict[str, 
object]] = [] - - for violation in violations: - if isinstance(violation, tuple): - violation_dict = dict(zip(["table", "rowid", "parent", "fkid"], violation)) - else: - violation_dict = dict(violation) - self.violations.append(violation_dict) - - super().__init__(f"Found {len(self.violations)} FK violations: {self.violations}") - - -class NestedForeignKeyDelayedRequestError(DBWrapperError): - def __init__(self) -> None: - super().__init__("Unable to enable delayed foreign key enforcement in a nested request.") - - class InternalError(DBWrapperError): pass @@ -159,6 +140,7 @@ async def savepoint_ctx(self, name: str) -> AsyncIterator[None]: # TODO: are these ok missing type parameters so they stay generic? or... TConnection = TypeVar("TConnection", bound=ConnectionProtocol) # type: ignore[type-arg] TConnection_co = TypeVar("TConnection_co", bound=ConnectionProtocol, covariant=True) # type: ignore[type-arg] +TUntransactionedConnection = TypeVar("TUntransactionedConnection") class CreateConnectionCallable(Protocol[TConnection_co]): @@ -191,8 +173,9 @@ async def manage_connection( @final @dataclass -class Transactioner(Generic[TConnection]): +class Transactioner(Generic[TConnection, TUntransactionedConnection]): create_connection: CreateConnectionCallable[TConnection] + create_untransactioned_connection: Callable[[TConnection], TUntransactionedConnection] _write_connection: TConnection db_version: int = 1 _log_file: Optional[TextIO] = None @@ -228,10 +211,7 @@ def _next_savepoint(self) -> str: return name @contextlib.asynccontextmanager - async def writer( - self, - foreign_key_enforcement_enabled: Optional[bool] = None, - ) -> AsyncIterator[TConnection]: + async def writer_no_transaction(self) -> AsyncIterator[TUntransactionedConnection]: """ Initiates a new, possibly nested, transaction. If this task is already in a transaction, none of the changes made as part of this transaction @@ -245,64 +225,46 @@ async def writer( task = asyncio.current_task() assert task is not None if self._current_writer == task: - # we allow nesting writers within the same task - if foreign_key_enforcement_enabled is not None: - # NOTE: Technically this is complaining even if the requested state is - # already in place. This could be adjusted to allow nesting - # when the existing and requested states agree. In this case, - # probably skip the nested foreign key check when exiting since - # we don't have many foreign key errors and so it is likely ok - # to save the extra time checking twice. 
-                raise NestedForeignKeyDelayedRequestError
-            async with self._write_connection.savepoint_ctx(name=self._next_savepoint()):
-                yield self._write_connection
+            if self._write_connection.in_transaction:
+                raise InternalError("cannot provide a transactionless connection while a transaction is active")
+
+            yield self.create_untransactioned_connection(self._write_connection)
             return
 
         async with self._lock:
-            async with contextlib.AsyncExitStack() as exit_stack:
-                if foreign_key_enforcement_enabled is not None:
-                    await exit_stack.enter_async_context(
-                        self._set_foreign_key_enforcement(enabled=foreign_key_enforcement_enabled),
-                    )
-
-                async with self._write_connection.savepoint_ctx(name=self._next_savepoint()):
-                    self._current_writer = task
-                    try:
-                        yield self._write_connection
-
-                        if foreign_key_enforcement_enabled is not None and not foreign_key_enforcement_enabled:
-                            await self._check_foreign_keys()
-                    finally:
-                        self._current_writer = None
+            self._current_writer = task
+            try:
+                yield self.create_untransactioned_connection(self._write_connection)
+            finally:
+                self._current_writer = None
 
     @contextlib.asynccontextmanager
-    async def _set_foreign_key_enforcement(self, enabled: bool) -> AsyncIterator[None]:
-        if self._current_writer is not None:
-            raise InternalError("Unable to set foreign key enforcement state while a writer is held")
-
-        async with self._write_connection.execute("PRAGMA foreign_keys") as cursor:
-            result = await cursor.fetchone()
-            if result is None:  # pragma: no cover
-                raise InternalError("No results when querying for present foreign key enforcement state")
-            [original_value] = result
-
-        if original_value == enabled:
-            yield
+    async def writer(self) -> AsyncIterator[TConnection]:
+        """
+        Initiates a new, possibly nested, transaction. If this task is already
+        in a transaction, none of the changes made as part of this transaction
+        will become visible to others until that top level transaction commits.
+        If this transaction fails (by exiting the context manager with an
+        exception) this transaction will be rolled back, but the next outer
+        transaction is not necessarily cancelled. It would also need to exit
+        with an exception to be cancelled.
+        The sqlite features this relies on are SAVEPOINT, ROLLBACK TO and RELEASE. 
+ """ + task = asyncio.current_task() + assert task is not None + if self._current_writer == task: + # we allow nesting writers within the same task + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + yield self._write_connection return - try: - await self._write_connection.execute(f"PRAGMA foreign_keys={enabled}") - yield - finally: - with anyio.CancelScope(shield=True): - await self._write_connection.execute(f"PRAGMA foreign_keys={original_value}") - - async def _check_foreign_keys(self) -> None: - async with self._write_connection.execute("PRAGMA foreign_key_check") as cursor: - violations = list(await cursor.fetchall()) - - if len(violations) > 0: - raise ForeignKeyError(violations=violations) + async with self._lock: + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + self._current_writer = task + try: + yield self._write_connection + finally: + self._current_writer = None @contextlib.asynccontextmanager async def writer_maybe_transaction(self) -> AsyncIterator[TConnection]: From fbd8315f36fc40b0a4f6bd3b38c4f9e702c5f2c8 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 29 Jul 2025 15:12:13 -0400 Subject: [PATCH 5/6] a few more sqlite specifics moved out of transactioner --- chia/_tests/db/test_sqlite_wrapper.py | 6 +- chia/util/sqlite_wrapper.py | 58 +++++++++++++++++- chia/util/transactioner.py | 87 +-------------------------- 3 files changed, 60 insertions(+), 91 deletions(-) diff --git a/chia/_tests/db/test_sqlite_wrapper.py b/chia/_tests/db/test_sqlite_wrapper.py index 8fd6e2422acf..b21f0e5324ee 100644 --- a/chia/_tests/db/test_sqlite_wrapper.py +++ b/chia/_tests/db/test_sqlite_wrapper.py @@ -22,12 +22,10 @@ NestedForeignKeyDelayedRequestError, SqliteConnection, SqliteTransactioner, -) -from chia.util.task_referencer import create_referenced_task -from chia.util.transactioner import ( - InternalError, generate_in_memory_db_uri, ) +from chia.util.task_referencer import create_referenced_task +from chia.util.transactioner import InternalError DBWrapper2 = SqliteTransactioner diff --git a/chia/util/sqlite_wrapper.py b/chia/util/sqlite_wrapper.py index 85db6927b1d7..9d17b5d95a47 100644 --- a/chia/util/sqlite_wrapper.py +++ b/chia/util/sqlite_wrapper.py @@ -2,8 +2,12 @@ import contextlib import functools +import secrets +import sqlite3 +import sys from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass +from datetime import datetime from pathlib import Path from types import TracebackType from typing import TYPE_CHECKING, Any, ClassVar, Optional, TextIO, TypeAlias, Union, cast @@ -20,7 +24,6 @@ InternalError, Transactioner, manage_connection, - sql_trace_callback, ) SqliteTransactioner: TypeAlias = Transactioner["SqliteConnection", "UntransactionedSqliteConnection"] @@ -29,6 +32,58 @@ # _protocol_check: AwaitableEnterable[Cursor] = cast(aiosqlite.Cursor, None) +if aiosqlite.sqlite_version_info < (3, 32, 0): + SQLITE_MAX_VARIABLE_NUMBER = 900 +else: + SQLITE_MAX_VARIABLE_NUMBER = 32700 + +# integers in sqlite are limited by int64 +SQLITE_INT_MAX = 2**63 - 1 + + +def generate_in_memory_db_uri() -> str: + # We need to use shared cache as our DB wrapper uses different types of connections + return f"file:db_{secrets.token_hex(16)}?mode=memory&cache=shared" + + +# async def execute_fetchone( +# c: aiosqlite.Connection, sql: str, parameters: Optional[Iterable[Any]] = None +# ) -> Optional[sqlite3.Row]: +# rows = await c.execute_fetchall(sql, parameters) +# for row in rows: +# return row +# return 
None + + +def sql_trace_callback(req: str, file: TextIO, name: Optional[str] = None) -> None: + timestamp = datetime.now().strftime("%H:%M:%S.%f") + if name is not None: + line = f"{timestamp} {name} {req}\n" + else: + line = f"{timestamp} {req}\n" + file.write(line) + + +def get_host_parameter_limit() -> int: + # NOTE: This does not account for dynamically adjusted limits since it makes a + # separate db and connection. If aiosqlite adds support we should use it. + if sys.version_info >= (3, 11): + connection = sqlite3.connect(":memory:") + + limit_number = sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER + host_parameter_limit = connection.getlimit(limit_number) + else: + # guessing based on defaults, seems you can't query + + # https://www.sqlite.org/changes.html#version_3_32_0 + # Increase the default upper bound on the number of parameters from 999 to 32766. + if sqlite3.sqlite_version_info >= (3, 32, 0): + host_parameter_limit = 32766 + else: + host_parameter_limit = 999 + return host_parameter_limit + + # TODO: think about the inheritance # class NestedForeignKeyDelayedRequestError(DBWrapperError): class NestedForeignKeyDelayedRequestError(Exception): @@ -83,6 +138,7 @@ class SqliteConnection: ) _connection: aiosqlite.Connection + host_parameter_limit: ClassVar[int] = get_host_parameter_limit() async def close(self) -> None: return await self._connection.close() diff --git a/chia/util/transactioner.py b/chia/util/transactioner.py index 52c16d0194fe..e726c57a701a 100644 --- a/chia/util/transactioner.py +++ b/chia/util/transactioner.py @@ -4,28 +4,15 @@ import asyncio import contextlib -import secrets -import sqlite3 -import sys -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator from dataclasses import dataclass, field -from datetime import datetime from pathlib import Path from types import TracebackType from typing import Any, Callable, Generic, Optional, Protocol, TextIO, TypeVar, Union -import aiosqlite import anyio from typing_extensions import Self, final -if aiosqlite.sqlite_version_info < (3, 32, 0): - SQLITE_MAX_VARIABLE_NUMBER = 900 -else: - SQLITE_MAX_VARIABLE_NUMBER = 32700 - -# integers in sqlite are limited by int64 -SQLITE_INT_MAX = 2**63 - 1 - class DBWrapperError(Exception): pass @@ -42,73 +29,6 @@ def __init__(self, obj: object) -> None: self.obj = obj -def generate_in_memory_db_uri() -> str: - # We need to use shared cache as our DB wrapper uses different types of connections - return f"file:db_{secrets.token_hex(16)}?mode=memory&cache=shared" - - -async def execute_fetchone( - c: aiosqlite.Connection, sql: str, parameters: Optional[Iterable[Any]] = None -) -> Optional[sqlite3.Row]: - rows = await c.execute_fetchall(sql, parameters) - for row in rows: - return row - return None - - -def sql_trace_callback(req: str, file: TextIO, name: Optional[str] = None) -> None: - timestamp = datetime.now().strftime("%H:%M:%S.%f") - if name is not None: - line = f"{timestamp} {name} {req}\n" - else: - line = f"{timestamp} {req}\n" - file.write(line) - - -def get_host_parameter_limit() -> int: - # NOTE: This does not account for dynamically adjusted limits since it makes a - # separate db and connection. If aiosqlite adds support we should use it. 
- if sys.version_info >= (3, 11): - connection = sqlite3.connect(":memory:") - - limit_number = sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER - host_parameter_limit = connection.getlimit(limit_number) - else: - # guessing based on defaults, seems you can't query - - # https://www.sqlite.org/changes.html#version_3_32_0 - # Increase the default upper bound on the number of parameters from 999 to 32766. - if sqlite3.sqlite_version_info >= (3, 32, 0): - host_parameter_limit = 32766 - else: - host_parameter_limit = 999 - return host_parameter_limit - - -# class CursorProtocol(Protocol): -# async def close(self) -> None: ... -# # async def __aenter__(self) -> Self: ... -# # async def __aexit__( -# # self, -# # exc_type: Optional[type[BaseException]], -# # exc_val: Optional[BaseException], -# # exc_tb: Optional[TracebackType], -# # ) -> None: ... -# -# -# TCursor_co = TypeVar("TCursor_co", bound=CursorProtocol, covariant=True) -# -# class AwaitableEnterable(Protocol[TCursor_co]): -# def __await__(self) -> TCursor_co: ... -# async def __aenter__(self) -> TCursor_co: ... -# async def __aexit__( -# self, -# exc_type: Optional[type[BaseException]], -# exc_val: Optional[BaseException], -# exc_tb: Optional[TracebackType], -# ) -> None: ... - - T_co = TypeVar("T_co", covariant=True) @@ -135,8 +55,6 @@ async def savepoint_ctx(self, name: str) -> AsyncIterator[None]: yield -# a_cursor: CursorProtocol = cast(aiosqlite.Cursor, None) - # TODO: are these ok missing type parameters so they stay generic? or... TConnection = TypeVar("TConnection", bound=ConnectionProtocol) # type: ignore[type-arg] TConnection_co = TypeVar("TConnection_co", bound=ConnectionProtocol, covariant=True) # type: ignore[type-arg] @@ -179,7 +97,6 @@ class Transactioner(Generic[TConnection, TUntransactionedConnection]): _write_connection: TConnection db_version: int = 1 _log_file: Optional[TextIO] = None - host_parameter_limit: int = get_host_parameter_limit() _lock: asyncio.Lock = field(default_factory=asyncio.Lock) _read_connections: asyncio.Queue[TConnection] = field(default_factory=asyncio.Queue) _num_read_connections: int = 0 @@ -220,7 +137,6 @@ async def writer_no_transaction(self) -> AsyncIterator[TUntransactionedConnectio exception) this transaction will be rolled back, but the next outer transaction is not necessarily cancelled. It would also need to exit with an exception to be cancelled. - The sqlite features this relies on are SAVEPOINT, ROLLBACK TO and RELEASE. """ task = asyncio.current_task() assert task is not None @@ -248,7 +164,6 @@ async def writer(self) -> AsyncIterator[TConnection]: exception) this transaction will be rolled back, but the next outer transaction is not necessarily cancelled. It would also need to exit with an exception to be cancelled. - The sqlite features this relies on are SAVEPOINT, ROLLBACK TO and RELEASE. 
""" task = asyncio.current_task() assert task is not None From 4f11676b41ffad31b2b99d0eee498fae802609ee Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 4 Aug 2025 12:41:40 -0400 Subject: [PATCH 6/6] rename to `.writer_outside_transaction()` --- chia/_tests/db/test_sqlite_wrapper.py | 8 ++++---- chia/data_layer/data_store.py | 6 +++--- chia/util/transactioner.py | 14 ++++++-------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/chia/_tests/db/test_sqlite_wrapper.py b/chia/_tests/db/test_sqlite_wrapper.py index b21f0e5324ee..8434463a7c7f 100644 --- a/chia/_tests/db/test_sqlite_wrapper.py +++ b/chia/_tests/db/test_sqlite_wrapper.py @@ -477,7 +477,7 @@ async def test_cancelled_reader_does_not_cancel_writer() -> None: @pytest.mark.anyio async def test_foreign_key_pragma_controlled_by_writer(initial: bool, forced: bool) -> None: async with DBConnection(2, foreign_keys=initial) as db_wrapper: - async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with db_wrapper.writer_outside_transaction() as writer_no_transaction: async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): async with db_wrapper.writer() as writer: async with writer.execute("PRAGMA foreign_keys") as cursor: @@ -518,7 +518,7 @@ async def test_foreign_key_pragma_rolls_back_on_foreign_key_error() -> None: # make sure the writer raises a foreign key error on exit with pytest.raises(ForeignKeyError): - async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with db_wrapper.writer_outside_transaction() as writer_no_transaction: async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): async with db_wrapper.writer() as writer: async with writer.execute("DELETE FROM people WHERE id = 1"): @@ -579,7 +579,7 @@ async def test_foreign_key_check_failure_error_message(case: RowFactoryCase) -> # make sure the writer raises a foreign key error on exit with pytest.raises(ForeignKeyError) as error: - async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with db_wrapper.writer_outside_transaction() as writer_no_transaction: async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): async with db_wrapper.writer() as writer: async with writer.execute("DELETE FROM people WHERE id = 1"): @@ -606,7 +606,7 @@ async def test_delayed_foreign_key_request_fails_when_nested(initial: bool) -> N async with DBConnection(2, foreign_keys=initial) as db_wrapper: async with db_wrapper.writer(): with pytest.raises(NestedForeignKeyDelayedRequestError): - async with db_wrapper.writer_no_transaction() as writer_no_transaction: + async with db_wrapper.writer_outside_transaction() as writer_no_transaction: async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): async with db_wrapper.writer(): pass # pragma: no cover diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index aa196c276e02..174589722025 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -199,7 +199,7 @@ async def migrate_db(self) -> None: version = "v1.0" log.info(f"Initiating migration to version {version}") - async with self.db_wrapper.writer_no_transaction() as writer_no_transaction: + async with self.db_wrapper.writer_outside_transaction() as writer_no_transaction: async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): async with self.db_wrapper.writer() as writer: await writer.execute( @@ -1432,7 +1432,7 @@ async def clean_node_table(self, 
writer: Optional[SqliteConnection] = None) -> N
         """
         params = {"pending_status": Status.PENDING.value, "pending_batch_status": Status.PENDING_BATCH.value}
         if writer is None:
-            async with self.db_wrapper.writer_no_transaction() as writer_no_transaction:
+            async with self.db_wrapper.writer_outside_transaction() as writer_no_transaction:
                 async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False):
                     async with self.db_wrapper.writer() as writer:
                         await writer.execute(query, params)
@@ -2071,7 +2071,7 @@ async def remove_subscriptions(self, store_id: bytes32, urls: list[str]) -> None
             )
 
     async def delete_store_data(self, store_id: bytes32) -> None:
-        async with self.db_wrapper.writer_no_transaction() as writer_no_transaction:
+        async with self.db_wrapper.writer_outside_transaction() as writer_no_transaction:
             async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False):
                 async with self.db_wrapper.writer() as writer:
                     await self.clean_node_table(writer)
diff --git a/chia/util/transactioner.py b/chia/util/transactioner.py
index e726c57a701a..e39bbf45ee61 100644
--- a/chia/util/transactioner.py
+++ b/chia/util/transactioner.py
@@ -128,15 +128,13 @@ def _next_savepoint(self) -> str:
         return name
 
     @contextlib.asynccontextmanager
-    async def writer_no_transaction(self) -> AsyncIterator[TUntransactionedConnection]:
+    async def writer_outside_transaction(self) -> AsyncIterator[TUntransactionedConnection]:
         """
-        Initiates a new, possibly nested, transaction. If this task is already
-        in a transaction, none of the changes made as part of this transaction
-        will become visible to others until that top level transaction commits.
-        If this transaction fails (by exiting the context manager with an
-        exception) this transaction will be rolled back, but the next outer
-        transaction is not necessarily cancelled. It would also need to exit
-        with an exception to be cancelled.
+        Provides a connection without any active transaction. These connection
+        objects are intentionally very limited. A SQLite-specific example is
+        executing the pragmas that control foreign key enforcement, which have
+        no effect inside a transaction. If this task is already in a
+        transaction, an error is raised immediately.
         """
         task = asyncio.current_task()
         assert task is not None