Skip to content

Commit 9a6a22d

Browse files
committed
Continue implementing tests. Extra nodes does not work for some reason
1 parent bcaa733 commit 9a6a22d

File tree

8 files changed

+315
-4
lines changed

8 files changed

+315
-4
lines changed
380 Bytes
Binary file not shown.

psqlpy_piccolo/engine.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,10 @@ def get_savepoint_id(self: Self) -> int:
276276
self._savepoint_id += 1
277277
return self._savepoint_id
278278

279+
async def rollback_to(self, savepoint_name: str) -> None:
280+
"""Use to rollback to a savepoint just using the name."""
281+
await Savepoint(name=savepoint_name, transaction=self).rollback_to()
282+
279283
async def savepoint(self: Self, name: str | None = None) -> Savepoint:
280284
"""Create new savepoint.
281285
@@ -665,7 +669,7 @@ async def run_querystring(
665669

666670
if self.log_queries:
667671
self.print_query(query_id=query_id, query=querystring.__str__())
668-
672+
print(querystring)
669673
# If running inside a transaction:
670674
current_transaction = self.current_transaction.get()
671675
if current_transaction:

tests/conftest.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,65 @@
1+
import asyncio
2+
import typing
3+
from unittest.mock import MagicMock
4+
15
import pytest
26

7+
from tests.test_apps.mega.tables import MegaTable, SmallTable
8+
from tests.test_apps.music.tables import (
9+
Band,
10+
Concert,
11+
Instrument,
12+
Manager,
13+
Poster,
14+
RecordingStudio,
15+
Shirt,
16+
Ticket,
17+
Venue,
18+
)
19+
20+
if typing.TYPE_CHECKING:
21+
from piccolo.table import Table
22+
323
pytestmark = [pytest.mark.anyio]
424

525

626
@pytest.fixture(scope="session", autouse=True)
727
def anyio_backend() -> str:
828
return "asyncio"
29+
30+
31+
@pytest.fixture(autouse=True)
32+
async def _clean_up() -> None:
33+
tables_to_clean: typing.Final[list[Table]] = [
34+
MegaTable,
35+
SmallTable,
36+
Manager,
37+
Band,
38+
Venue,
39+
Concert,
40+
Ticket,
41+
Poster,
42+
Shirt,
43+
RecordingStudio,
44+
Instrument,
45+
]
46+
for table_to_clean in tables_to_clean:
47+
await table_to_clean.delete(force=True)
48+
49+
50+
class AsyncMock(MagicMock):
51+
"""
52+
Async MagicMock for python 3.7+.
53+
54+
This is a workaround for the fact that MagicMock is not async compatible in
55+
Python 3.7.
56+
"""
57+
58+
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: # noqa: ANN401
59+
super().__init__(*args, **kwargs)
60+
61+
# this makes asyncio.iscoroutinefunction(AsyncMock()) return True
62+
self._is_coroutine = asyncio.coroutines.iscoroutine
63+
64+
async def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> None: # noqa: ANN401
65+
return super(AsyncMock, self).__call__(*args, **kwargs) # noqa: UP008

tests/test_apps/mega/piccolo_app.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22

33
from piccolo.conf.apps import AppConfig
44

5-
from tests.test_apps.mega.tables import MegaTable, SmallTable
6-
75
CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) # noqa: PTH120, PTH100
86

97

108
APP_CONFIG = AppConfig(
119
app_name="mega",
12-
table_classes=[MegaTable, SmallTable],
10+
table_classes=[], # During models import some circular import happening for some reason.
1311
migrations_folder_path=os.path.join(CURRENT_DIRECTORY, "piccolo_migrations"), # noqa: PTH118
1412
commands=[],
1513
)

tests/test_extra_node.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import typing
2+
from unittest.mock import MagicMock
3+
4+
import pytest
5+
from piccolo.columns.column_types import Varchar
6+
from piccolo.engine import engine_finder
7+
from piccolo.table import Table
8+
9+
from psqlpy_piccolo import PSQLPyEngine
10+
from tests.conftest import AsyncMock
11+
12+
13+
def test_extra_nodes() -> None:
14+
"""Make sure that other nodes can be queried."""
15+
test_engine = engine_finder()
16+
assert test_engine is not None
17+
18+
test_engine = typing.cast(PSQLPyEngine, test_engine)
19+
20+
EXTRA_NODE: typing.Final = MagicMock(spec=PSQLPyEngine(config=test_engine.config)) # noqa: N806
21+
EXTRA_NODE.run_querystring = AsyncMock(return_value=[])
22+
23+
DB: typing.Final = PSQLPyEngine( # noqa: N806
24+
config=test_engine.config,
25+
extra_nodes={"read_1": EXTRA_NODE},
26+
)
27+
28+
class Manager(Table, db=DB): # type: ignore[call-arg]
29+
name = Varchar()
30+
31+
# Make sure the node is queried
32+
Manager.select().run_sync(node="read_1")
33+
assert EXTRA_NODE.run_querystring.called
34+
35+
# Make sure that a non existent node raises an error
36+
with pytest.raises(KeyError):
37+
Manager.select().run_sync(node="read_2")

tests/test_nested_transaction.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from piccolo.engine.exceptions import TransactionError
3+
4+
from tests.test_apps.music.tables import Manager
5+
6+
7+
async def test_nested_transactions() -> None:
8+
"""Make sure nested transactions databases work as expected."""
9+
async with Manager._meta.db.transaction():
10+
await Manager(name="Bob").save().run()
11+
12+
async with Manager._meta.db.transaction():
13+
await Manager(name="Dave").save().run()
14+
15+
assert await Manager.count().run() == 2 # noqa: PLR2004
16+
17+
assert await Manager.count().run() == 2 # noqa: PLR2004
18+
19+
20+
async def test_nested_transactions_error() -> None:
21+
"""Make sure nested transactions databases work as expected."""
22+
async with Manager._meta.db.transaction():
23+
await Manager(name="Bob").save().run()
24+
25+
with pytest.raises(TransactionError):
26+
async with Manager._meta.db.transaction(allow_nested=False):
27+
await Manager(name="Dave").save().run()
28+
29+
assert await Manager.count().run() == 1
30+
31+
assert await Manager.count().run() == 1

tests/test_pool.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
import asyncio
2+
import os
3+
import tempfile
14
import typing
5+
from unittest.mock import call, patch
6+
7+
from piccolo.engine.sqlite import SQLiteEngine
28

39
from psqlpy_piccolo import PSQLPyEngine
410
from tests.test_apps.music.tables import Manager
@@ -22,3 +28,49 @@ async def test_make_query() -> None:
2228
assert "Bob" in [instance["name"] for instance in response]
2329

2430
await Manager._meta.db.close_connection_pool()
31+
32+
33+
async def test_make_many_queries() -> None:
34+
await Manager._meta.db.start_connection_pool()
35+
36+
await Manager(name="Bob").save().run()
37+
38+
async def get_data() -> None:
39+
response = await Manager.select().run()
40+
assert response[0]["name"] == "Bob"
41+
42+
await asyncio.gather(*[get_data() for _ in range(500)])
43+
44+
await Manager._meta.db.close_connection_pool()
45+
46+
47+
async def test_proxy_methods() -> None:
48+
engine: typing.Final = typing.cast(PSQLPyEngine, Manager._meta.db)
49+
50+
# Deliberate typo ('nnn'):
51+
await engine.start_connnection_pool()
52+
assert engine.pool is not None
53+
54+
# Deliberate typo ('nnn'):
55+
await engine.close_connnection_pool()
56+
assert engine.pool is None
57+
58+
59+
async def test_warnings() -> None:
60+
sqlite_file = os.path.join(tempfile.gettempdir(), "engine.sqlite") # noqa: PTH118
61+
engine = SQLiteEngine(path=sqlite_file)
62+
63+
with patch("piccolo.engine.base.colored_warning") as colored_warning:
64+
await engine.start_connection_pool()
65+
await engine.close_connection_pool()
66+
67+
assert colored_warning.call_args_list == [
68+
call(
69+
"Connection pooling is not supported for sqlite.",
70+
stacklevel=3,
71+
),
72+
call(
73+
"Connection pooling is not supported for sqlite.",
74+
stacklevel=3,
75+
),
76+
]

tests/test_transaction.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
import typing
4+
5+
import pytest
6+
7+
from psqlpy_piccolo import PSQLPyEngine
8+
from tests.test_apps.music.tables import Band, Manager
9+
10+
11+
def test_atomic_error_statement() -> None:
12+
"""Make sure queries in a transaction aren't committed if a query fails."""
13+
atomic = Band._meta.db.atomic()
14+
atomic.add(
15+
Band.raw("MALFORMED QUERY ... SHOULD ERROR"),
16+
)
17+
with pytest.raises(Exception): # noqa: B017, PT011
18+
atomic.run_sync()
19+
20+
21+
def test_atomic_succeeds_statement() -> None:
22+
"""Make sure that when atomic is run successfully the database is modified accordingly."""
23+
atomic = Band._meta.db.atomic()
24+
atomic.add(Manager.insert(Manager(name="test-manager-name")))
25+
atomic.run_sync()
26+
assert Manager.count().run_sync() == 1
27+
28+
29+
async def test_atomic_pool() -> None:
30+
"""Make sure atomic works correctly when a connection pool is active."""
31+
engine = Manager._meta.db
32+
await engine.start_connection_pool()
33+
34+
atomic = engine.atomic()
35+
atomic.add(Manager.insert(Manager(name="test-manager-name")))
36+
37+
await atomic.run()
38+
await engine.close_connection_pool()
39+
40+
assert Manager.count().run_sync() == 1
41+
42+
43+
async def test_transaction_error() -> None:
44+
"""Make sure queries in a transaction aren't committed if a query fails."""
45+
with pytest.raises(Exception): # noqa: B017, PT011, PT012
46+
async with Manager._meta.db.transaction():
47+
await Manager.insert(Manager(name="test-manager-name"))
48+
await Manager.raw("MALFORMED QUERY ... SHOULD ERROR")
49+
50+
assert Manager.count().run_sync() == 0
51+
52+
53+
async def test_transaction_succeeds() -> None:
54+
async with Manager._meta.db.transaction():
55+
await Manager.insert(Manager(name="test-manager-name"))
56+
57+
assert Manager.count().run_sync() == 1
58+
59+
60+
async def test_transaction_manual_commit() -> None:
61+
async with Band._meta.db.transaction() as transaction:
62+
await Manager.insert(Manager(name="test-manager-name"))
63+
await transaction.commit()
64+
65+
assert Manager.count().run_sync() == 1
66+
67+
68+
async def test_transaction_manual_rollback() -> None:
69+
async with Band._meta.db.transaction() as transaction:
70+
await Manager.insert(Manager(name="test-manager-name"))
71+
await transaction.rollback()
72+
73+
assert Manager.count().run_sync() == 0
74+
75+
76+
async def test_transaction_id() -> None:
77+
"""An extra sanity check, that the transaction id is the same for each query inside the transaction block."""
78+
79+
async def get_transaction_ids() -> list[str]:
80+
responses = []
81+
async with Band._meta.db.transaction():
82+
responses.append(await Manager.raw("SELECT txid_current()").run())
83+
responses.append(await Manager.raw("SELECT txid_current()").run())
84+
85+
return [response[0]["txid_current"] for response in responses]
86+
87+
transaction_ids: typing.Final = await get_transaction_ids()
88+
assert len(set(transaction_ids)) == 1
89+
90+
next_transaction_ids: typing.Final = await get_transaction_ids()
91+
assert len(set(next_transaction_ids)) == 1
92+
assert next_transaction_ids[0] != transaction_ids[0]
93+
94+
95+
async def test_transaction_exists() -> None:
96+
"""Make sure we can detect when code is within a transaction."""
97+
engine: typing.Final = typing.cast(PSQLPyEngine, Manager._meta.db)
98+
99+
async with engine.transaction():
100+
assert engine.transaction_exists()
101+
102+
assert not engine.transaction_exists()
103+
104+
105+
async def test_transaction_savepoint() -> None:
106+
async with Manager._meta.db.transaction() as transaction:
107+
await Manager.insert(Manager(name="Manager 1"))
108+
savepoint = await transaction.savepoint()
109+
await Manager.insert(Manager(name="Manager 2"))
110+
await savepoint.rollback_to()
111+
112+
assert await Manager.select(Manager.name).run() == [{"name": "Manager 1"}]
113+
114+
115+
async def test_transaction_named_savepoint() -> None:
116+
async with Manager._meta.db.transaction() as transaction:
117+
await Manager.insert(Manager(name="Manager 1"))
118+
await transaction.savepoint("my_savepoint1")
119+
await Manager.insert(Manager(name="Manager 2"))
120+
await transaction.savepoint("my_savepoint2")
121+
await Manager.insert(Manager(name="Manager 3"))
122+
await transaction.rollback_to("my_savepoint1")
123+
124+
assert await Manager.select(Manager.name).run() == [{"name": "Manager 1"}]
125+
126+
127+
async def test_savepoint_sqli_checks() -> None:
128+
with pytest.raises(ValueError): # noqa: PT011
129+
async with Manager._meta.db.transaction() as transaction:
130+
await transaction.savepoint(
131+
"my_savepoint; SELECT * FROM Manager",
132+
)

0 commit comments

Comments
 (0)