diff --git a/chia/_tests/db/test_sqlite_wrapper.py b/chia/_tests/db/test_sqlite_wrapper.py new file mode 100644 index 000000000000..8434463a7c7f --- /dev/null +++ b/chia/_tests/db/test_sqlite_wrapper.py @@ -0,0 +1,612 @@ +from __future__ import annotations + +import asyncio +import contextlib +import tempfile +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Optional + +import aiosqlite +import pytest + +# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469 +from _pytest.fixtures import SubRequest + +from chia._tests.util.misc import Marks, boolean_datacases, datacases +from chia.util import sqlite_wrapper +from chia.util.sqlite_wrapper import ( + ForeignKeyError, + NestedForeignKeyDelayedRequestError, + SqliteConnection, + SqliteTransactioner, + generate_in_memory_db_uri, +) +from chia.util.task_referencer import create_referenced_task +from chia.util.transactioner import InternalError + +DBWrapper2 = SqliteTransactioner + + +@asynccontextmanager +async def DBConnection( + db_version: int, + foreign_keys: Optional[bool] = None, + row_factory: Optional[type[aiosqlite.Row]] = None, +) -> AsyncIterator[DBWrapper2]: + db_uri = generate_in_memory_db_uri() + async with sqlite_wrapper.managed( + database=db_uri, + uri=True, + reader_count=4, + db_version=db_version, + foreign_keys=foreign_keys, + row_factory=row_factory, + ) as _db_wrapper: + yield _db_wrapper + + +@asynccontextmanager +async def PathDBConnection(db_version: int) -> AsyncIterator[DBWrapper2]: + with tempfile.TemporaryDirectory() as directory: + db_path = Path(directory).joinpath("db.sqlite") + async with sqlite_wrapper.managed(database=db_path, reader_count=4, db_version=db_version) as _db_wrapper: + yield _db_wrapper + + +if TYPE_CHECKING: + ConnectionContextManager = contextlib.AbstractAsyncContextManager[SqliteConnection] + GetReaderMethod = Callable[[DBWrapper2], Callable[[], ConnectionContextManager]] + + +class UniqueError(Exception): + """Used to uniquely trigger the exception path out of the context managers.""" + + +async def increment_counter(db_wrapper: DBWrapper2) -> None: + async with db_wrapper.writer_maybe_transaction() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [old_value] = row + + await asyncio.sleep(0) + + new_value = old_value + 1 + await connection.execute("UPDATE counter SET value = :value", {"value": new_value}) + + +async def decrement_counter(db_wrapper: DBWrapper2) -> None: + async with db_wrapper.writer_maybe_transaction() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [old_value] = row + + await asyncio.sleep(0) + + new_value = old_value - 1 + await connection.execute("UPDATE counter SET value = :value", {"value": new_value}) + + +async def sum_counter(db_wrapper: DBWrapper2, output: list[int]) -> None: + async with db_wrapper.reader_no_transaction() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [value] = row + + output.append(value) + + +async def setup_table(db: DBWrapper2) -> None: + async with db.writer_maybe_transaction() as conn: + await conn.execute("CREATE TABLE counter(value INTEGER NOT NULL)") + await 
conn.execute("INSERT INTO counter(value) VALUES(0)") + + +async def get_value(cursor: aiosqlite.Cursor) -> int: + row = await cursor.fetchone() + assert row + return int(row[0]) + + +async def query_value(connection: SqliteConnection) -> int: + async with connection.execute("SELECT value FROM counter") as cursor: + return await get_value(cursor=cursor) + + +def _get_reader_no_transaction_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]: + return db_wrapper.reader_no_transaction + + +def _get_regular_reader_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]: + return db_wrapper.reader + + +@pytest.fixture( + name="get_reader_method", + params=[ + pytest.param(_get_reader_no_transaction_method, id="reader_no_transaction"), + pytest.param(_get_regular_reader_method, id="reader"), + ], +) +def get_reader_method_fixture(request: SubRequest) -> Callable[[], ConnectionContextManager]: + # https://github.com/pytest-dev/pytest/issues/8763 + return request.param # type: ignore[no-any-return] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + argnames="acquire_outside", + argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")], +) +async def test_concurrent_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + concurrent_task_count = 200 + + async with contextlib.AsyncExitStack() as exit_stack: + if acquire_outside: + await exit_stack.enter_async_context(db_wrapper.writer_maybe_transaction()) + + tasks = [] + for index in range(concurrent_task_count): + task = create_referenced_task(increment_counter(db_wrapper)) + tasks.append(task) + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=None) + + async with get_reader_method(db_wrapper)() as connection: + async with connection.execute("SELECT value FROM counter") as cursor: + row = await cursor.fetchone() + + assert row is not None + [value] = row + + assert value == concurrent_task_count + + +@pytest.mark.anyio +async def test_writers_nests() -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + async with db_wrapper.writer_maybe_transaction() as conn1: + async with conn1.execute("SELECT value FROM counter") as cursor: + value = await get_value(cursor) + async with db_wrapper.writer_maybe_transaction() as conn2: + assert conn1 == conn2 + value += 1 + await conn2.execute("UPDATE counter SET value = :value", {"value": value}) + async with db_wrapper.writer_maybe_transaction() as conn3: + assert conn1 == conn3 + async with conn3.execute("SELECT value FROM counter") as cursor: + value = await get_value(cursor) + + assert value == 1 + + +@pytest.mark.anyio +async def test_writer_journal_mode_wal() -> None: + async with PathDBConnection(2) as db_wrapper: + async with db_wrapper.writer() as connection: + async with connection.execute("PRAGMA journal_mode") as cursor: + result = await cursor.fetchone() + assert result == ("wal",) + + +@pytest.mark.anyio +async def test_reader_journal_mode_wal() -> None: + async with PathDBConnection(2) as db_wrapper: + async with db_wrapper.reader_no_transaction() as connection: + async with connection.execute("PRAGMA journal_mode") as cursor: + result = await cursor.fetchone() + assert result == ("wal",) + + +@pytest.mark.anyio +async def test_partial_failure() -> None: + values = [] + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + async with db_wrapper.writer() as 
conn1:
+            await conn1.execute("UPDATE counter SET value = 42")
+            async with conn1.execute("SELECT value FROM counter") as cursor:
+                values.append(await get_value(cursor))
+            try:
+                async with db_wrapper.writer() as conn2:
+                    await conn2.execute("UPDATE counter SET value = 1337")
+                    async with conn1.execute("SELECT value FROM counter") as cursor:
+                        values.append(await get_value(cursor))
+                    # this simulates a failure, which will cause a rollback of the
+                    # write we just made, back to 42
+                    raise RuntimeError("failure within a sub-transaction")
+            except RuntimeError:
+                # we expect to get here
+                values.append(1)
+            async with conn1.execute("SELECT value FROM counter") as cursor:
+                values.append(await get_value(cursor))
+
+    # the write of 1337 failed, and was restored to 42
+    assert values == [42, 1337, 1, 42]
+
+
+@pytest.mark.anyio
+async def test_readers_nests(get_reader_method: GetReaderMethod) -> None:
+    async with DBConnection(2) as db_wrapper:
+        await setup_table(db_wrapper)
+
+        async with get_reader_method(db_wrapper)() as conn1:
+            async with get_reader_method(db_wrapper)() as conn2:
+                assert conn1 == conn2
+                async with get_reader_method(db_wrapper)() as conn3:
+                    assert conn1 == conn3
+                    async with conn3.execute("SELECT value FROM counter") as cursor:
+                        value = await get_value(cursor)
+
+    assert value == 0
+
+
+@pytest.mark.anyio
+async def test_readers_nests_writer(get_reader_method: GetReaderMethod) -> None:
+    async with DBConnection(2) as db_wrapper:
+        await setup_table(db_wrapper)
+
+        async with db_wrapper.writer_maybe_transaction() as conn1:
+            async with get_reader_method(db_wrapper)() as conn2:
+                assert conn1 == conn2
+                async with db_wrapper.writer_maybe_transaction() as conn3:
+                    assert conn1 == conn3
+                    async with conn3.execute("SELECT value FROM counter") as cursor:
+                        value = await get_value(cursor)
+
+    assert value == 0
+
+
+@pytest.mark.parametrize(
+    argnames="transactioned",
+    argvalues=[
+        pytest.param(True, id="transaction"),
+        pytest.param(False, id="no transaction"),
+    ],
+)
+@pytest.mark.anyio
+async def test_only_transactioned_reader_ignores_writer(transactioned: bool) -> None:
+    writer_committed = asyncio.Event()
+    reader_read = asyncio.Event()
+
+    async def write() -> None:
+        try:
+            async with db_wrapper.writer() as writer:
+                assert reader is not writer
+
+                await writer.execute("UPDATE counter SET value = 1")
+        finally:
+            writer_committed.set()
+
+        await reader_read.wait()
+
+        assert await query_value(connection=writer) == 1
+
+    async with PathDBConnection(2) as db_wrapper:
+        get_reader = db_wrapper.reader if transactioned else db_wrapper.reader_no_transaction
+
+        await setup_table(db_wrapper)
+
+        async with get_reader() as reader:
+            assert await query_value(connection=reader) == 0
+
+            task = create_referenced_task(write())
+            await writer_committed.wait()
+
+            assert await query_value(connection=reader) == (0 if transactioned else 1)
+            reader_read.set()
+
+        await task
+
+        async with get_reader() as reader:
+            assert await query_value(connection=reader) == 1
+
+
+@pytest.mark.anyio
+async def test_reader_nests_and_ends_transaction() -> None:
+    async with DBConnection(2) as db_wrapper:
+        async with db_wrapper.reader() as reader:
+            assert reader.in_transaction
+
+            async with db_wrapper.reader() as inner_reader:
+                assert inner_reader is reader
+                assert reader.in_transaction
+
+            assert reader.in_transaction
+
+        assert not reader.in_transaction
+
+
+@pytest.mark.anyio
+async def test_writer_in_reader_works() -> None:
+    async with PathDBConnection(2) as db_wrapper:
+        await 
setup_table(db_wrapper) + + async with db_wrapper.reader() as reader: + async with db_wrapper.writer() as writer: + assert writer is not reader + await writer.execute("UPDATE counter SET value = 1") + assert await query_value(connection=writer) == 1 + assert await query_value(connection=reader) == 0 + + assert await query_value(connection=reader) == 0 + + +@pytest.mark.anyio +async def test_reader_transaction_is_deferred() -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.reader() as reader: + async with db_wrapper.writer() as writer: + assert writer is not reader + await writer.execute("UPDATE counter SET value = 1") + assert await query_value(connection=writer) == 1 + + # The deferred transaction initiation results in the transaction starting + # here and thus reading the written value. + assert await query_value(connection=reader) == 1 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + argnames="acquire_outside", + argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")], +) +async def test_concurrent_readers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None: + async with DBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.writer_maybe_transaction() as connection: + await connection.execute("UPDATE counter SET value = 1") + + concurrent_task_count = 200 + + async with contextlib.AsyncExitStack() as exit_stack: + if acquire_outside: + await exit_stack.enter_async_context(get_reader_method(db_wrapper)()) + + tasks = [] + values: list[int] = [] + for index in range(concurrent_task_count): + task = create_referenced_task(sum_counter(db_wrapper, values)) + tasks.append(task) + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=None) + + assert values == [1] * concurrent_task_count + + +@pytest.mark.anyio +@pytest.mark.parametrize( + argnames="acquire_outside", + argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")], +) +async def test_mixed_readers_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None: + async with PathDBConnection(2) as db_wrapper: + await setup_table(db_wrapper) + + async with db_wrapper.writer_maybe_transaction() as connection: + await connection.execute("UPDATE counter SET value = 1") + + concurrent_task_count = 200 + + async with contextlib.AsyncExitStack() as exit_stack: + if acquire_outside: + await exit_stack.enter_async_context(get_reader_method(db_wrapper)()) + + tasks = [] + values: list[int] = [] + for index in range(concurrent_task_count): + task = create_referenced_task(increment_counter(db_wrapper)) + tasks.append(task) + task = create_referenced_task(decrement_counter(db_wrapper)) + tasks.append(task) + task = create_referenced_task(sum_counter(db_wrapper, values)) + tasks.append(task) + + await asyncio.wait_for(asyncio.gather(*tasks), timeout=None) + + # we increment and decrement the counter an equal number of times. It should + # end back at 1. 
+        async with get_reader_method(db_wrapper)() as connection:
+            async with connection.execute("SELECT value FROM counter") as cursor:
+                row = await cursor.fetchone()
+                assert row is not None
+                assert row[0] == 1
+
+        # it's possible all increments or all decrements are run first
+        assert len(values) == concurrent_task_count
+        for v in values:
+            assert v > -99
+            assert v <= 100
+
+
+@pytest.mark.parametrize(
+    argnames=["manager_method", "expected"],
+    argvalues=[
+        [DBWrapper2.writer, True],
+        [DBWrapper2.writer_maybe_transaction, True],
+        [DBWrapper2.reader, True],
+        [DBWrapper2.reader_no_transaction, False],
+    ],
+)
+@pytest.mark.anyio
+async def test_in_transaction_as_expected(
+    manager_method: Callable[[DBWrapper2], ConnectionContextManager],
+    expected: bool,
+) -> None:
+    async with DBConnection(2) as db_wrapper:
+        await setup_table(db_wrapper)
+
+        async with manager_method(db_wrapper) as connection:
+            assert connection.in_transaction == expected
+
+
+@pytest.mark.anyio
+async def test_cancelled_reader_does_not_cancel_writer() -> None:
+    async with DBConnection(2) as db_wrapper:
+        await setup_table(db_wrapper)
+
+        async with db_wrapper.writer() as writer:
+            await writer.execute("UPDATE counter SET value = 1")
+
+            with pytest.raises(UniqueError):
+                async with db_wrapper.reader() as _:
+                    raise UniqueError
+
+            assert await query_value(connection=writer) == 1
+
+        assert await query_value(connection=writer) == 1
+
+
+@boolean_datacases(name="initial", false="initially disabled", true="initially enabled")
+@boolean_datacases(name="forced", false="forced disabled", true="forced enabled")
+@pytest.mark.anyio
+async def test_foreign_key_pragma_controlled_by_writer(initial: bool, forced: bool) -> None:
+    async with DBConnection(2, foreign_keys=initial) as db_wrapper:
+        async with db_wrapper.writer_outside_transaction() as writer_no_transaction:
+            # pass the parametrized `forced` state through so the assertion below
+            # actually exercises both enforcement states
+            async with writer_no_transaction.delay(foreign_key_enforcement_enabled=forced):
+                async with db_wrapper.writer() as writer:
+                    async with writer.execute("PRAGMA foreign_keys") as cursor:
+                        result = await cursor.fetchone()
+                        assert result is not None
+                        [actual] = result
+
+                    assert actual == (1 if forced else 0)
+
+
+@pytest.mark.anyio
+async def test_foreign_key_pragma_rolls_back_on_foreign_key_error() -> None:
+    async with DBConnection(2, foreign_keys=True, row_factory=aiosqlite.Row) as db_wrapper:
+        async with db_wrapper.writer() as writer:
+            async with writer.execute(
+                """
+                CREATE TABLE people(
+                    id INTEGER NOT NULL,
+                    friend INTEGER,
+                    PRIMARY KEY (id),
+                    FOREIGN KEY (friend) REFERENCES people
+                )
+                """
+            ):
+                pass
+
+            async with writer.execute(
+                "INSERT INTO people(id, friend) VALUES (:id, :friend)",
+                {"id": 1, "friend": None},
+            ):
+                pass
+
+            async with writer.execute(
+                "INSERT INTO people(id, friend) VALUES (:id, :friend)",
+                {"id": 2, "friend": 1},
+            ):
+                pass
+
+        # make sure the writer raises a foreign key error on exit
+        with pytest.raises(ForeignKeyError):
+            async with db_wrapper.writer_outside_transaction() as writer_no_transaction:
+                async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False):
+                    async with db_wrapper.writer() as writer:
+                        async with writer.execute("DELETE FROM people WHERE id = 1"):
+                            pass
+
+                        # make sure a foreign key error can be detected here
+                        with pytest.raises(ForeignKeyError):
+                            await writer._check_foreign_keys()
+
+        async with writer.execute("SELECT * FROM people WHERE id = 1") as cursor:
+            [person] = await cursor.fetchall()
+
+        # make sure the delete was rolled back
+        assert dict(person) == 
{"id": 1, "friend": None} + + +@dataclass +class RowFactoryCase: + id: str + factory: Optional[type[aiosqlite.Row]] + marks: Marks = () + + +row_factory_cases: list[RowFactoryCase] = [ + RowFactoryCase(id="default named tuple", factory=None), + RowFactoryCase(id="aiosqlite row", factory=aiosqlite.Row), +] + + +@datacases(*row_factory_cases) +@pytest.mark.anyio +async def test_foreign_key_check_failure_error_message(case: RowFactoryCase) -> None: + async with DBConnection(2, foreign_keys=True, row_factory=case.factory) as db_wrapper: + async with db_wrapper.writer() as writer: + async with writer.execute( + """ + CREATE TABLE people( + id INTEGER NOT NULL, + friend INTEGER, + PRIMARY KEY (id), + FOREIGN KEY (friend) REFERENCES people + ) + """ + ): + pass + + async with writer.execute( + "INSERT INTO people(id, friend) VALUES (:id, :friend)", + {"id": 1, "friend": None}, + ): + pass + + async with writer.execute( + "INSERT INTO people(id, friend) VALUES (:id, :friend)", + {"id": 2, "friend": 1}, + ): + pass + + # make sure the writer raises a foreign key error on exit + with pytest.raises(ForeignKeyError) as error: + async with db_wrapper.writer_outside_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with db_wrapper.writer() as writer: + async with writer.execute("DELETE FROM people WHERE id = 1"): + pass + + assert error.value.violations == [{"table": "people", "rowid": 2, "parent": "people", "fkid": 0}] + + +@pytest.mark.anyio +async def test_set_foreign_keys_fails_within_acquired_writer() -> None: + async with DBConnection(2, foreign_keys=True) as db_wrapper: + async with db_wrapper.writer() as writer: + with pytest.raises( + InternalError, + match="Unable to set foreign key enforcement state while a writer is held", + ): + async with writer._set_foreign_key_enforcement(enabled=False): + pass # pragma: no cover + + +@boolean_datacases(name="initial", false="initially disabled", true="initially enabled") +@pytest.mark.anyio +async def test_delayed_foreign_key_request_fails_when_nested(initial: bool) -> None: + async with DBConnection(2, foreign_keys=initial) as db_wrapper: + async with db_wrapper.writer(): + with pytest.raises(NestedForeignKeyDelayedRequestError): + async with db_wrapper.writer_outside_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with db_wrapper.writer(): + pass # pragma: no cover diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py index 517b763517a0..14ce9a55ad29 100644 --- a/chia/data_layer/data_layer_util.py +++ b/chia/data_layer/data_layer_util.py @@ -14,8 +14,8 @@ from chia.data_layer.data_layer_errors import ProofIntegrityError from chia.server.ws_connection import WSChiaConnection from chia.types.blockchain_format.program import Program +from chia.util import sqlite_wrapper from chia.util.byte_types import hexstr_to_bytes -from chia.util.db_wrapper import DBWrapper2 from chia.util.streamable import Streamable, streamable from chia.wallet.db_wallet.db_wallet_puzzles import create_host_fullpuz @@ -77,7 +77,7 @@ def get_hashes_for_page(page: int, lengths: dict[bytes32, int], max_page_size: i return PaginationData(current_page + 1, total_bytes, hashes) -async def _debug_dump(db: DBWrapper2, description: str = "") -> None: +async def _debug_dump(db: sqlite_wrapper.SqliteTransactioner, description: str = "") -> None: async with db.reader() as reader: cursor = await 
reader.execute("SELECT name FROM sqlite_master WHERE type='table';") print("-" * 50, description, flush=True) diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 039b68973693..174589722025 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -42,7 +42,9 @@ unspecified, ) from chia.types.blockchain_format.program import Program -from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2 +from chia.util import sqlite_wrapper +from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER +from chia.util.sqlite_wrapper import SqliteConnection log = logging.getLogger(__name__) @@ -55,14 +57,14 @@ class DataStore: """A key/value store with the pairs being terminal nodes in a CLVM object tree.""" - db_wrapper: DBWrapper2 + db_wrapper: sqlite_wrapper.SqliteTransactioner @classmethod @contextlib.asynccontextmanager async def managed( cls, database: Union[str, Path], uri: bool = False, sql_log_path: Optional[Path] = None ) -> AsyncIterator[DataStore]: - async with DBWrapper2.managed( + async with sqlite_wrapper.managed( database=database, uri=uri, journal_mode="WAL", @@ -197,25 +199,27 @@ async def migrate_db(self) -> None: version = "v1.0" log.info(f"Initiating migration to version {version}") - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await writer.execute( - f""" - CREATE TABLE IF NOT EXISTS new_root( - tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), - generation INTEGER NOT NULL CHECK(generation >= 0), - node_hash BLOB, - status INTEGER NOT NULL CHECK( - {" OR ".join(f"status == {status}" for status in Status)} - ), - PRIMARY KEY(tree_id, generation), - FOREIGN KEY(node_hash) REFERENCES node(hash) - ) - """ - ) - await writer.execute("INSERT INTO new_root SELECT * FROM root") - await writer.execute("DROP TABLE root") - await writer.execute("ALTER TABLE new_root RENAME TO root") - await writer.execute("INSERT INTO schema (version_id) VALUES (?)", (version,)) + async with self.db_wrapper.writer_outside_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with self.db_wrapper.writer() as writer: + await writer.execute( + f""" + CREATE TABLE IF NOT EXISTS new_root( + tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), + generation INTEGER NOT NULL CHECK(generation >= 0), + node_hash BLOB, + status INTEGER NOT NULL CHECK( + {" OR ".join(f"status == {status}" for status in Status)} + ), + PRIMARY KEY(tree_id, generation), + FOREIGN KEY(node_hash) REFERENCES node(hash) + ) + """ + ) + await writer.execute("INSERT INTO new_root SELECT * FROM root") + await writer.execute("DROP TABLE root") + await writer.execute("ALTER TABLE new_root RENAME TO root") + await writer.execute("INSERT INTO schema (version_id) VALUES (?)", (version,)) log.info(f"Finished migrating DB to version {version}") async def _insert_root( @@ -756,7 +760,7 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3 async def get_keys_values_cursor( self, - reader: aiosqlite.Connection, + reader: SqliteConnection, root_hash: Optional[bytes32], only_keys: bool = False, ) -> aiosqlite.Cursor: @@ -1404,7 +1408,7 @@ async def upsert( return InsertResult(node_hash=new_terminal_node_hash, root=new_root) - async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None) -> None: + async def clean_node_table(self, writer: Optional[SqliteConnection] = None) -> None: query = """ WITH RECURSIVE pending_nodes AS ( 
SELECT node_hash AS hash FROM root @@ -1428,8 +1432,10 @@ async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None) """ params = {"pending_status": Status.PENDING.value, "pending_batch_status": Status.PENDING_BATCH.value} if writer is None: - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await writer.execute(query, params) + async with self.db_wrapper.writer_outside_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with self.db_wrapper.writer() as writer: + await writer.execute(query, params) else: await writer.execute(query, params) @@ -2065,76 +2071,78 @@ async def remove_subscriptions(self, store_id: bytes32, urls: list[str]) -> None ) async def delete_store_data(self, store_id: bytes32) -> None: - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await self.clean_node_table(writer) - cursor = await writer.execute( - """ - WITH RECURSIVE all_nodes AS ( - SELECT a.hash, n.left, n.right - FROM ancestors AS a - JOIN node AS n ON a.hash = n.hash - WHERE a.tree_id = :tree_id - ), - pending_nodes AS ( - SELECT node_hash AS hash FROM root - WHERE status IN (:pending_status, :pending_batch_status) - UNION ALL - SELECT n.left FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.left IS NOT NULL - UNION ALL - SELECT n.right FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.right IS NOT NULL - ) + async with self.db_wrapper.writer_outside_transaction() as writer_no_transaction: + async with writer_no_transaction.delay(foreign_key_enforcement_enabled=False): + async with self.db_wrapper.writer() as writer: + await self.clean_node_table(writer) + cursor = await writer.execute( + """ + WITH RECURSIVE all_nodes AS ( + SELECT a.hash, n.left, n.right + FROM ancestors AS a + JOIN node AS n ON a.hash = n.hash + WHERE a.tree_id = :tree_id + ), + pending_nodes AS ( + SELECT node_hash AS hash FROM root + WHERE status IN (:pending_status, :pending_batch_status) + UNION ALL + SELECT n.left FROM node n + INNER JOIN pending_nodes pn ON n.hash = pn.hash + WHERE n.left IS NOT NULL + UNION ALL + SELECT n.right FROM node n + INNER JOIN pending_nodes pn ON n.hash = pn.hash + WHERE n.right IS NOT NULL + ) - SELECT hash, left, right - FROM all_nodes - WHERE hash NOT IN (SELECT hash FROM ancestors WHERE tree_id != :tree_id) - AND hash NOT IN (SELECT hash from pending_nodes) - """, - { - "tree_id": store_id, - "pending_status": Status.PENDING.value, - "pending_batch_status": Status.PENDING_BATCH.value, - }, - ) - to_delete: dict[bytes, tuple[bytes, bytes]] = {} - ref_counts: dict[bytes, int] = {} - async for row in cursor: - hash = row["hash"] - left = row["left"] - right = row["right"] - if hash in to_delete: - prev_left, prev_right = to_delete[hash] - assert prev_left == left - assert prev_right == right - continue - to_delete[hash] = (left, right) - if left is not None: - ref_counts[left] = ref_counts.get(left, 0) + 1 - if right is not None: - ref_counts[right] = ref_counts.get(right, 0) + 1 - - await writer.execute("DELETE FROM ancestors WHERE tree_id == ?", (store_id,)) - await writer.execute("DELETE FROM root WHERE tree_id == ?", (store_id,)) - queue = [hash for hash in to_delete if ref_counts.get(hash, 0) == 0] - while queue: - hash = queue.pop(0) - if hash not in to_delete: - continue - await writer.execute("DELETE FROM node WHERE hash == ?", (hash,)) - - left, right = to_delete[hash] - if left is 
not None: - ref_counts[left] -= 1 - if ref_counts[left] == 0: - queue.append(left) - - if right is not None: - ref_counts[right] -= 1 - if ref_counts[right] == 0: - queue.append(right) + SELECT hash, left, right + FROM all_nodes + WHERE hash NOT IN (SELECT hash FROM ancestors WHERE tree_id != :tree_id) + AND hash NOT IN (SELECT hash from pending_nodes) + """, + { + "tree_id": store_id, + "pending_status": Status.PENDING.value, + "pending_batch_status": Status.PENDING_BATCH.value, + }, + ) + to_delete: dict[bytes, tuple[bytes, bytes]] = {} + ref_counts: dict[bytes, int] = {} + async for row in cursor: + hash = row["hash"] + left = row["left"] + right = row["right"] + if hash in to_delete: + prev_left, prev_right = to_delete[hash] + assert prev_left == left + assert prev_right == right + continue + to_delete[hash] = (left, right) + if left is not None: + ref_counts[left] = ref_counts.get(left, 0) + 1 + if right is not None: + ref_counts[right] = ref_counts.get(right, 0) + 1 + + await writer.execute("DELETE FROM ancestors WHERE tree_id == ?", (store_id,)) + await writer.execute("DELETE FROM root WHERE tree_id == ?", (store_id,)) + queue = [hash for hash in to_delete if ref_counts.get(hash, 0) == 0] + while queue: + hash = queue.pop(0) + if hash not in to_delete: + continue + await writer.execute("DELETE FROM node WHERE hash == ?", (hash,)) + + left, right = to_delete[hash] + if left is not None: + ref_counts[left] -= 1 + if ref_counts[left] == 0: + queue.append(left) + + if right is not None: + ref_counts[right] -= 1 + if ref_counts[right] == 0: + queue.append(right) async def unsubscribe(self, store_id: bytes32) -> None: async with self.db_wrapper.writer() as writer: diff --git a/chia/util/sqlite_wrapper.py b/chia/util/sqlite_wrapper.py new file mode 100644 index 000000000000..9d17b5d95a47 --- /dev/null +++ b/chia/util/sqlite_wrapper.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +import contextlib +import functools +import secrets +import sqlite3 +import sys +from collections.abc import AsyncIterator, Iterable +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar, Optional, TextIO, TypeAlias, Union, cast + +import aiosqlite +import aiosqlite.context +import anyio +from aiosqlite import Cursor +from typing_extensions import Self + +from chia.util.transactioner import ( + ConnectionProtocol, + CreateConnectionCallable, + InternalError, + Transactioner, + manage_connection, +) + +SqliteTransactioner: TypeAlias = Transactioner["SqliteConnection", "UntransactionedSqliteConnection"] + +# if TYPE_CHECKING: +# _protocol_check: AwaitableEnterable[Cursor] = cast(aiosqlite.Cursor, None) + + +if aiosqlite.sqlite_version_info < (3, 32, 0): + SQLITE_MAX_VARIABLE_NUMBER = 900 +else: + SQLITE_MAX_VARIABLE_NUMBER = 32700 + +# integers in sqlite are limited by int64 +SQLITE_INT_MAX = 2**63 - 1 + + +def generate_in_memory_db_uri() -> str: + # We need to use shared cache as our DB wrapper uses different types of connections + return f"file:db_{secrets.token_hex(16)}?mode=memory&cache=shared" + + +# async def execute_fetchone( +# c: aiosqlite.Connection, sql: str, parameters: Optional[Iterable[Any]] = None +# ) -> Optional[sqlite3.Row]: +# rows = await c.execute_fetchall(sql, parameters) +# for row in rows: +# return row +# return None + + +def sql_trace_callback(req: str, file: TextIO, name: Optional[str] = None) -> None: + timestamp = datetime.now().strftime("%H:%M:%S.%f") + if 
name is not None: + line = f"{timestamp} {name} {req}\n" + else: + line = f"{timestamp} {req}\n" + file.write(line) + + +def get_host_parameter_limit() -> int: + # NOTE: This does not account for dynamically adjusted limits since it makes a + # separate db and connection. If aiosqlite adds support we should use it. + if sys.version_info >= (3, 11): + connection = sqlite3.connect(":memory:") + + limit_number = sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER + host_parameter_limit = connection.getlimit(limit_number) + else: + # guessing based on defaults, seems you can't query + + # https://www.sqlite.org/changes.html#version_3_32_0 + # Increase the default upper bound on the number of parameters from 999 to 32766. + if sqlite3.sqlite_version_info >= (3, 32, 0): + host_parameter_limit = 32766 + else: + host_parameter_limit = 999 + return host_parameter_limit + + +# TODO: think about the inheritance +# class NestedForeignKeyDelayedRequestError(DBWrapperError): +class NestedForeignKeyDelayedRequestError(Exception): + def __init__(self) -> None: + super().__init__("Unable to enable delayed foreign key enforcement in a nested request.") + + +# TODO: think about the inheritance +# class ForeignKeyError(DBWrapperError): +class ForeignKeyError(Exception): + def __init__(self, violations: Iterable[Union[aiosqlite.Row, tuple[str, object, str, object]]]) -> None: + self.violations: list[dict[str, object]] = [] + + for violation in violations: + if isinstance(violation, tuple): + violation_dict = dict(zip(["table", "rowid", "parent", "fkid"], violation)) + else: + violation_dict = dict(violation) + self.violations.append(violation_dict) + + super().__init__(f"Found {len(self.violations)} FK violations: {self.violations}") + + +@dataclass +class UntransactionedSqliteConnection: + _connection: SqliteConnection + + @contextlib.asynccontextmanager + async def delay(self, *, foreign_key_enforcement_enabled: bool) -> AsyncIterator[None]: + if self._connection.in_transaction and foreign_key_enforcement_enabled is not None: + # NOTE: Technically this is complaining even if the requested state is + # already in place. This could be adjusted to allow nesting + # when the existing and requested states agree. In this case, + # probably skip the nested foreign key check when exiting since + # we don't have many foreign key errors and so it is likely ok + # to save the extra time checking twice. 
+ raise NestedForeignKeyDelayedRequestError + + async with self._connection._set_foreign_key_enforcement(enabled=foreign_key_enforcement_enabled): + async with self._connection.savepoint_ctx("delay"): + try: + yield + finally: + await self._connection._check_foreign_keys() + + +@dataclass +class SqliteConnection: + if TYPE_CHECKING: + _protocol_check: ClassVar[ConnectionProtocol[aiosqlite.context.Result[aiosqlite.Cursor]]] = cast( + "SqliteConnection", None + ) + + _connection: aiosqlite.Connection + host_parameter_limit: ClassVar[int] = get_host_parameter_limit() + + async def close(self) -> None: + return await self._connection.close() + + def execute(self, *args: Any, **kwargs: Any) -> aiosqlite.context.Result[Cursor]: + return self._connection.execute(*args, **kwargs) + + @property + def in_transaction(self) -> bool: + return self._connection.in_transaction + + async def rollback(self) -> None: + await self._connection.rollback() + + async def configure_as_reader(self) -> None: + await self.execute("pragma query_only") + + async def read_transaction(self) -> None: + await self.execute("BEGIN DEFERRED;") + + @contextlib.asynccontextmanager + async def _set_foreign_key_enforcement(self, enabled: bool) -> AsyncIterator[None]: + if self.in_transaction: + raise InternalError("Unable to set foreign key enforcement state while a writer is held") + + async with self._connection.execute("PRAGMA foreign_keys") as cursor: + result = await cursor.fetchone() + if result is None: # pragma: no cover + raise InternalError("No results when querying for present foreign key enforcement state") + [original_value] = result + + if original_value == enabled: + yield + return + + try: + await self._connection.execute(f"PRAGMA foreign_keys={enabled}") + yield + finally: + with anyio.CancelScope(shield=True): + await self._connection.execute(f"PRAGMA foreign_keys={original_value}") + + async def _check_foreign_keys(self) -> None: + async with self._connection.execute("PRAGMA foreign_key_check") as cursor: + violations = list(await cursor.fetchall()) + + if len(violations) > 0: + raise ForeignKeyError(violations=violations) + + async def __aenter__(self) -> Self: + await self._connection.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return await self._connection.__aexit__(exc_type, exc_val, exc_tb) + + @contextlib.asynccontextmanager + async def savepoint_ctx(self, name: str) -> AsyncIterator[None]: + await self._connection.execute(f"SAVEPOINT {name}") + try: + yield + except: + await self._connection.execute(f"ROLLBACK TO {name}") + raise + finally: + # rollback to a savepoint doesn't cancel the transaction, it + # just rolls back the state. 
We still need to release it regardless
+            await self._connection.execute(f"RELEASE {name}")
+
+
+async def sqlite_create_connection(
+    database: Union[str, Path],
+    uri: bool = False,
+    log_file: Optional[TextIO] = None,
+    name: Optional[str] = None,
+    row_factory: Optional[type[aiosqlite.Row]] = None,
+) -> SqliteConnection:
+    # To avoid https://github.com/python/cpython/issues/118172
+    connection = await aiosqlite.connect(database=database, uri=uri, cached_statements=0)
+
+    if log_file is not None:
+        await connection.set_trace_callback(functools.partial(sql_trace_callback, file=log_file, name=name))
+
+    if row_factory is not None:
+        connection.row_factory = row_factory
+
+    return SqliteConnection(_connection=connection)
+
+
+@contextlib.asynccontextmanager
+async def managed(
+    database: Union[str, Path],
+    *,
+    db_version: int = 1,
+    uri: bool = False,
+    reader_count: int = 4,
+    log_path: Optional[Path] = None,
+    journal_mode: str = "WAL",
+    synchronous: Optional[str] = None,
+    foreign_keys: Optional[bool] = None,
+    row_factory: Optional[type[aiosqlite.Row]] = None,
+) -> AsyncIterator[SqliteTransactioner]:
+    if foreign_keys is None:
+        foreign_keys = False
+
+    async with contextlib.AsyncExitStack() as async_exit_stack:
+        if log_path is None:
+            log_file = None
+        else:
+            log_path.parent.mkdir(parents=True, exist_ok=True)
+            log_file = async_exit_stack.enter_context(log_path.open("a", encoding="utf-8"))
+
+        write_connection = await async_exit_stack.enter_async_context(
+            manage_connection(
+                create_connection=sqlite_create_connection, database=database, uri=uri, log_file=log_file, name="writer"
+            ),
+        )
+        await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close()
+        if synchronous is not None:
+            await (await write_connection.execute(f"pragma synchronous={synchronous}")).close()
+
+        await (await write_connection.execute(f"pragma foreign_keys={'ON' if foreign_keys else 'OFF'}")).close()
+
+        write_connection._connection.row_factory = row_factory
+
+        self = Transactioner(
+            create_connection=sqlite_create_connection,
+            create_untransactioned_connection=UntransactionedSqliteConnection,
+            _write_connection=write_connection,
+            db_version=db_version,
+            _log_file=log_file,
+        )
+
+        for index in range(reader_count):
+            read_connection = await async_exit_stack.enter_async_context(
+                manage_connection(
+                    create_connection=sqlite_create_connection,
+                    database=database,
+                    uri=uri,
+                    log_file=log_file,
+                    name=f"reader-{index}",
+                ),
+            )
+            read_connection._connection.row_factory = row_factory
+
+            await self.add_connection(c=read_connection)
+        try:
+            yield self
+        finally:
+            with anyio.CancelScope(shield=True):
+                while self._num_read_connections > 0:
+                    await self._read_connections.get()
+                    self._num_read_connections -= 1
+
+
+async def create(
+    create_connection: CreateConnectionCallable[SqliteConnection],
+    database: Union[str, Path],
+    *,
+    db_version: int = 1,
+    uri: bool = False,
+    reader_count: int = 4,
+    log_path: Optional[Path] = None,
+    journal_mode: str = "WAL",
+    synchronous: Optional[str] = None,
+    foreign_keys: bool = False,
+) -> SqliteTransactioner:
+    # WARNING: please use .managed() instead
+    if log_path is None:
+        log_file = None
+    else:
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+        log_file = log_path.open("a", encoding="utf-8")
+    write_connection = await create_connection(database=database, uri=uri, log_file=log_file, name="writer")
+    await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close()
+    if synchronous is not None:
+        await (await 
write_connection.execute(f"pragma synchronous={synchronous}")).close() + + await (await write_connection.execute(f"pragma foreign_keys={'ON' if foreign_keys else 'OFF'}")).close() + + self = Transactioner( + create_connection=create_connection, + create_untransactioned_connection=UntransactionedSqliteConnection, + _write_connection=write_connection, + db_version=db_version, + _log_file=log_file, + ) + + for index in range(reader_count): + read_connection = await create_connection( + database=database, + uri=uri, + log_file=log_file, + name=f"reader-{index}", + ) + await self.add_connection(c=read_connection) + + return self diff --git a/chia/util/transactioner.py b/chia/util/transactioner.py new file mode 100644 index 000000000000..e39bbf45ee61 --- /dev/null +++ b/chia/util/transactioner.py @@ -0,0 +1,251 @@ +# Package: utils + +from __future__ import annotations + +import asyncio +import contextlib +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from pathlib import Path +from types import TracebackType +from typing import Any, Callable, Generic, Optional, Protocol, TextIO, TypeVar, Union + +import anyio +from typing_extensions import Self, final + + +class DBWrapperError(Exception): + pass + + +class InternalError(DBWrapperError): + pass + + +class PurposefulAbort(DBWrapperError): + obj: object + + def __init__(self, obj: object) -> None: + self.obj = obj + + +T_co = TypeVar("T_co", covariant=True) + + +class ConnectionProtocol(Protocol[T_co]): + # TODO: this is presently matching aiosqlite.Connection, generalize + + async def close(self) -> None: ... + def execute(self, *args: Any, **kwargs: Any) -> T_co: ... + @property + def in_transaction(self) -> bool: ... + async def rollback(self) -> None: ... + async def configure_as_reader(self) -> None: ... + async def read_transaction(self) -> None: ... + async def __aenter__(self) -> Self: ... + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... + + @contextlib.asynccontextmanager + async def savepoint_ctx(self, name: str) -> AsyncIterator[None]: + yield + + +# TODO: are these ok missing type parameters so they stay generic? or... +TConnection = TypeVar("TConnection", bound=ConnectionProtocol) # type: ignore[type-arg] +TConnection_co = TypeVar("TConnection_co", bound=ConnectionProtocol, covariant=True) # type: ignore[type-arg] +TUntransactionedConnection = TypeVar("TUntransactionedConnection") + + +class CreateConnectionCallable(Protocol[TConnection_co]): + async def __call__( + self, + database: Union[str, Path], + uri: bool = False, + log_file: Optional[TextIO] = None, + name: Optional[str] = None, + ) -> TConnection_co: ... 
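+
+# NOTE (illustrative): sqlite_create_connection in chia.util.sqlite_wrapper is one
+# concrete implementation of this protocol; any backend whose factory accepts these
+# keyword arguments and returns its own ConnectionProtocol object can be used here.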
+
+
+@contextlib.asynccontextmanager
+async def manage_connection(
+    create_connection: CreateConnectionCallable[TConnection_co],
+    database: Union[str, Path],
+    uri: bool = False,
+    log_file: Optional[TextIO] = None,
+    name: Optional[str] = None,
+) -> AsyncIterator[TConnection_co]:
+    connection: TConnection_co
+    connection = await create_connection(database=database, uri=uri, log_file=log_file, name=name)
+
+    try:
+        yield connection
+    finally:
+        with anyio.CancelScope(shield=True):
+            await connection.close()
+
+
+@final
+@dataclass
+class Transactioner(Generic[TConnection, TUntransactionedConnection]):
+    create_connection: CreateConnectionCallable[TConnection]
+    create_untransactioned_connection: Callable[[TConnection], TUntransactionedConnection]
+    _write_connection: TConnection
+    db_version: int = 1
+    _log_file: Optional[TextIO] = None
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    _read_connections: asyncio.Queue[TConnection] = field(default_factory=asyncio.Queue)
+    _num_read_connections: int = 0
+    _in_use: dict[asyncio.Task[object], TConnection] = field(default_factory=dict)
+    _current_writer: Optional[asyncio.Task[object]] = None
+    _savepoint_name: int = 0
+
+    async def add_connection(self, c: TConnection) -> None:
+        # this guarantees that reader connections can only be used for reading
+        assert c != self._write_connection
+        await c.configure_as_reader()
+        self._read_connections.put_nowait(c)
+        self._num_read_connections += 1
+
+    async def close(self) -> None:
+        # WARNING: please use .managed() instead
+        try:
+            while self._num_read_connections > 0:
+                await (await self._read_connections.get()).close()
+                self._num_read_connections -= 1
+            await self._write_connection.close()
+        finally:
+            if self._log_file is not None:
+                self._log_file.close()
+
+    def _next_savepoint(self) -> str:
+        name = f"s{self._savepoint_name}"
+        self._savepoint_name += 1
+        return name
+
+    @contextlib.asynccontextmanager
+    async def writer_outside_transaction(self) -> AsyncIterator[TUntransactionedConnection]:
+        """
+        Provides a connection without any active transaction. These connection
+        objects are intentionally very limited. An SQLite-specific example use
+        is executing the pragmas that control foreign key enforcement, which
+        must run outside of a transaction. If this task is already in a
+        transaction, an error is raised immediately.
+        """
+        task = asyncio.current_task()
+        assert task is not None
+        if self._current_writer == task:
+            if self._write_connection.in_transaction:
+                raise InternalError("can't provide a transactionless writer while an active transaction is held")
+
+            yield self.create_untransactioned_connection(self._write_connection)
+            return
+
+        async with self._lock:
+            self._current_writer = task
+            try:
+                yield self.create_untransactioned_connection(self._write_connection)
+            finally:
+                self._current_writer = None
+
+    @contextlib.asynccontextmanager
+    async def writer(self) -> AsyncIterator[TConnection]:
+        """
+        Initiates a new, possibly nested, transaction. If this task is already
+        in a transaction, none of the changes made as part of this transaction
+        will become visible to others until that top-level transaction commits.
+        If this transaction fails (by exiting the context manager with an
+        exception), it is rolled back, but the next outer transaction is not
+        necessarily cancelled; it would also need to exit with an exception to
+        be cancelled.
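+
+        A minimal illustrative sketch of the nesting behavior (``transactioner``
+        stands for any instance of this class)::
+
+            async with transactioner.writer() as conn:           # outer transaction
+                async with transactioner.writer() as same_conn:  # nested savepoint
+                    await same_conn.execute("UPDATE counter SET value = 1")
+                # an exception raised here rolls back only the inner savepoint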
+ """ + task = asyncio.current_task() + assert task is not None + if self._current_writer == task: + # we allow nesting writers within the same task + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + yield self._write_connection + return + + async with self._lock: + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + self._current_writer = task + try: + yield self._write_connection + finally: + self._current_writer = None + + @contextlib.asynccontextmanager + async def writer_maybe_transaction(self) -> AsyncIterator[TConnection]: + """ + Initiates a write to the database. If this task is already in a write + transaction with the DB, this is a no-op. Any changes made to the + database will be rolled up into the transaction we're already in. If the + current task is not already in a transaction, one will be created and + committed (or rolled back in the case of an exception). + """ + task = asyncio.current_task() + assert task is not None + if self._current_writer == task: + # just use the existing transaction + yield self._write_connection + return + + async with self._lock: + async with self._write_connection.savepoint_ctx(name=self._next_savepoint()): + self._current_writer = task + try: + yield self._write_connection + finally: + self._current_writer = None + + @contextlib.asynccontextmanager + async def reader(self) -> AsyncIterator[TConnection]: + async with self.reader_no_transaction() as connection: + if connection.in_transaction: + yield connection + else: + await connection.read_transaction() + try: + yield connection + finally: + # close the transaction with a rollback instead of commit just in + # case any modifications were submitted through this reader + await connection.rollback() + + @contextlib.asynccontextmanager + async def reader_no_transaction(self) -> AsyncIterator[TConnection]: + # there should have been read connections added + assert self._num_read_connections > 0 + + # we can have multiple concurrent readers, just pick a connection from + # the pool of readers. If they're all busy, we'll wait for one to free + # up. + task = asyncio.current_task() + assert task is not None + + # if this task currently holds the write lock, use the same connection, + # so it can read back updates it has made to its transaction, even + # though it hasn't been committed yet + if self._current_writer == task: + # we allow nesting reading while also having a writer connection + # open, within the same task + yield self._write_connection + return + + if task in self._in_use: + yield self._in_use[task] + else: + c = await self._read_connections.get() + try: + # record our connection in this dict to allow nested calls in + # the same task to use the same connection + self._in_use[task] = c + yield c + finally: + del self._in_use[task] + self._read_connections.put_nowait(c)
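+
+
+# Usage sketch (illustrative; it mirrors the patterns exercised in
+# chia/_tests/db/test_sqlite_wrapper.py, which drive this class through
+# chia.util.sqlite_wrapper.managed):
+#
+#     async with sqlite_wrapper.managed(database=path, reader_count=4, db_version=2) as db:
+#         async with db.writer() as writer:      # exclusive; rolls back on exception
+#             await writer.execute("UPDATE counter SET value = 1")
+#         async with db.reader() as reader:      # pooled; deferred read transaction
+#             async with reader.execute("SELECT value FROM counter") as cursor:
+#                 row = await cursor.fetchone()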