Skip to content

Commit bcaa733

Browse files
committed
Make tests work
1 parent cf9ac9f commit bcaa733

File tree

8 files changed

+398
-495
lines changed

8 files changed

+398
-495
lines changed

piccolo_conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from piccolo.conf.apps import AppRegistry
1212

13-
from psqlpy_piccolo import PSQLPyEngine as Engine
13+
from psqlpy_piccolo import PSQLPyEngine
1414

15-
DB = Engine(
15+
DB = PSQLPyEngine(
1616
config={
1717
"host": os.environ.get("PG_HOST", "127.0.0.1"),
1818
"port": os.environ.get("PG_PORT", 5432),

poetry.lock

Lines changed: 93 additions & 187 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
-715 Bytes
Binary file not shown.

psqlpy_piccolo/engine.py

Lines changed: 60 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
from __future__ import annotations
2+
13
import contextvars
2-
import types
34
from dataclasses import dataclass
4-
from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Type, Union
5+
from typing import TYPE_CHECKING, Any, Generator, Mapping, Sequence
56

67
from piccolo.engine.base import BaseBatch, Engine, validate_savepoint_name
78
from piccolo.engine.exceptions import TransactionError
89
from piccolo.query.base import DDL, Query
9-
from piccolo.querystring import QueryString
1010
from piccolo.utils.sync import run_sync
1111
from piccolo.utils.warnings import Level, colored_warning
1212
from psqlpy import Connection, ConnectionPool, Cursor, Transaction
1313
from psqlpy.exceptions import RustPSQLDriverPyBaseError
1414
from typing_extensions import Self
1515

16+
if TYPE_CHECKING:
17+
import types
18+
19+
from piccolo.querystring import QueryString
20+
1621

1722
@dataclass
1823
class AsyncBatch(BaseBatch):
@@ -23,8 +28,8 @@ class AsyncBatch(BaseBatch):
2328
batch_size: int
2429

2530
# Set internally
26-
_transaction: Optional[Transaction] = None
27-
_cursor: Optional[Cursor] = None
31+
_transaction: Transaction | None = None
32+
_cursor: Cursor | None = None
2833

2934
@property
3035
def cursor(self) -> Cursor:
@@ -37,19 +42,19 @@ def cursor(self) -> Cursor:
3742
raise ValueError("_cursor not set")
3843
return self._cursor
3944

40-
async def next(self) -> List[Dict[str, Any]]:
45+
async def next(self) -> list[dict[str, Any]]:
4146
"""Retrieve next batch from the Cursor.
4247
4348
### Returns:
44-
List of dicts of results.
49+
list of dicts of results.
4550
"""
4651
data = await self.cursor.fetch(self.batch_size)
4752
return data.result()
4853

4954
def __aiter__(self: Self) -> Self:
5055
return self
5156

52-
async def __anext__(self: Self) -> List[Dict[str, Any]]:
57+
async def __anext__(self: Self) -> list[dict[str, Any]]:
5358
response = await self.next()
5459
if response == []:
5560
raise StopAsyncIteration
@@ -70,9 +75,9 @@ async def __aenter__(self: Self) -> Self:
7075

7176
async def __aexit__(
7277
self: Self,
73-
exception_type: Optional[Type[BaseException]],
74-
exception: Optional[BaseException],
75-
traceback: Optional[types.TracebackType],
78+
exception_type: type[BaseException] | None,
79+
exception: BaseException | None,
80+
traceback: types.TracebackType | None,
7681
) -> bool:
7782
if exception:
7883
await self._transaction.rollback() # type: ignore[union-attr]
@@ -98,19 +103,19 @@ class Atomic:
98103

99104
__slots__ = ("engine", "queries")
100105

101-
def __init__(self: Self, engine: "PSQLPyEngine") -> None:
106+
def __init__(self: Self, engine: PSQLPyEngine) -> None:
102107
"""Initialize programmatically configured atomic transaction.
103108
104109
### Parameters:
105110
- `engine`: engine for query executing.
106111
"""
107112
self.engine = engine
108-
self.queries: List[Union[Query[Any, Any], DDL]] = []
113+
self.queries: list[Query[Any, Any] | DDL] = []
109114

110115
def __await__(self: Self) -> Generator[Any, None, None]:
111116
return self.run().__await__()
112117

113-
def add(self: Self, *query: Union[Query[Any, Any], DDL]) -> None:
118+
def add(self: Self, *query: Query[Any, Any] | DDL) -> None:
114119
"""Add query to atomic transaction.
115120
116121
### Params:
@@ -128,7 +133,7 @@ async def run(self: Self) -> None:
128133
if isinstance(query, (Query, DDL, Create, GetOrCreate)):
129134
await query.run()
130135
else:
131-
raise ValueError("Unrecognised query")
136+
raise TypeError("Unrecognised query") # noqa: TRY301
132137
self.queries = []
133138
except Exception as exception:
134139
self.queries = []
@@ -142,7 +147,7 @@ def run_sync(self: Self) -> None:
142147
class Savepoint:
143148
"""PostgreSQL `SAVEPOINT` representation in Python."""
144149

145-
def __init__(self: Self, name: str, transaction: "PostgresTransaction") -> None:
150+
def __init__(self: Self, name: str, transaction: PostgresTransaction) -> None:
146151
"""Initialize new `SAVEPOINT`.
147152
148153
### Parameters:
@@ -179,7 +184,7 @@ class PostgresTransaction:
179184
180185
"""
181186

182-
def __init__(self: Self, engine: "PSQLPyEngine", allow_nested: bool = True) -> None:
187+
def __init__(self: Self, engine: PSQLPyEngine, allow_nested: bool = True) -> None:
183188
"""Initialize new transaction.
184189
185190
### Parameters:
@@ -204,7 +209,7 @@ def __init__(self: Self, engine: "PSQLPyEngine", allow_nested: bool = True) -> N
204209
"aren't allowed.",
205210
)
206211

207-
async def __aenter__(self: Self) -> "PostgresTransaction":
212+
async def __aenter__(self: Self) -> Self:
208213
if self._parent is not None:
209214
return self._parent
210215

@@ -218,9 +223,9 @@ async def __aenter__(self: Self) -> "PostgresTransaction":
218223

219224
async def __aexit__(
220225
self: Self,
221-
exception_type: Optional[Type[BaseException]],
222-
exception: Optional[BaseException],
223-
traceback: Optional[types.TracebackType],
226+
exception_type: type[BaseException] | None,
227+
exception: BaseException | None,
228+
traceback: types.TracebackType | None,
224229
) -> bool:
225230
if self._parent:
226231
return exception is None
@@ -271,7 +276,7 @@ def get_savepoint_id(self: Self) -> int:
271276
self._savepoint_id += 1
272277
return self._savepoint_id
273278

274-
async def savepoint(self: Self, name: Optional[str] = None) -> Savepoint:
279+
async def savepoint(self: Self, name: str | None = None) -> Savepoint:
275280
"""Create new savepoint.
276281
277282
### Parameters:
@@ -351,11 +356,11 @@ class PSQLPyEngine(Engine[PostgresTransaction]):
351356

352357
def __init__(
353358
self: Self,
354-
config: Dict[str, Any],
359+
config: dict[str, Any],
355360
extensions: Sequence[str] = ("uuid-ossp",),
356361
log_queries: bool = False,
357362
log_responses: bool = False,
358-
extra_nodes: Optional[Mapping[str, "PSQLPyEngine"]] = None,
363+
extra_nodes: Mapping[str, PSQLPyEngine] | None = None,
359364
) -> None:
360365
"""Initialize `PSQLPyEngine`.
361366
@@ -421,7 +426,7 @@ def __init__(
421426
self.log_queries = log_queries
422427
self.log_responses = log_responses
423428
self.extra_nodes = extra_nodes
424-
self.pool: Optional[ConnectionPool] = None
429+
self.pool: ConnectionPool | None = None
425430
database_name = config.get("database", "Unknown")
426431
self.current_transaction = contextvars.ContextVar(
427432
f"pg_current_transaction_{database_name}",
@@ -449,7 +454,7 @@ def _parse_raw_version_string(version_string: str) -> float:
449454
async def get_version(self: Self) -> float:
450455
"""Retrieve the version of Postgres being run."""
451456
try:
452-
response: Sequence[Dict[str, Any]] = await self._run_in_new_connection(
457+
response: Sequence[dict[str, Any]] = await self._run_in_new_connection(
453458
"SHOW server_version",
454459
)
455460
except ConnectionRefusedError as exception:
@@ -475,7 +480,7 @@ async def prep_database(self: Self) -> None:
475480
await self._run_in_new_connection(
476481
f'CREATE EXTENSION IF NOT EXISTS "{extension}"',
477482
)
478-
except RustPSQLDriverPyBaseError:
483+
except RustPSQLDriverPyBaseError: # noqa: PERF203
479484
colored_warning(
480485
f"=> Unable to create {extension} extension - some "
481486
"functionality may not behave as expected. Make sure "
@@ -487,7 +492,7 @@ async def prep_database(self: Self) -> None:
487492

488493
async def start_connnection_pool(
489494
self: Self,
490-
**kwargs: Dict[str, Any],
495+
**_kwargs: dict[str, Any],
491496
) -> None:
492497
"""Start new connection pool.
493498
@@ -504,7 +509,7 @@ async def start_connnection_pool(
504509
)
505510
return await self.start_connection_pool()
506511

507-
async def close_connnection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
512+
async def close_connnection_pool(self: Self, **_kwargs: dict[str, Any]) -> None:
508513
"""Close connection pool."""
509514
colored_warning(
510515
"`close_connnection_pool` is a typo - please change it to "
@@ -513,7 +518,7 @@ async def close_connnection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
513518
)
514519
return await self.close_connection_pool()
515520

516-
async def start_connection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
521+
async def start_connection_pool(self: Self, **kwargs: dict[str, Any]) -> None:
517522
"""Start new connection pool.
518523
519524
Create and start new connection pool.
@@ -530,9 +535,6 @@ async def start_connection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
530535
else:
531536
config = dict(self.config)
532537
config.update(**kwargs)
533-
print("----------------")
534-
print(config)
535-
print("----------------")
536538
self.pool = ConnectionPool(
537539
db_name=config.pop("database", None),
538540
username=config.pop("user", None),
@@ -549,7 +551,7 @@ async def close_connection_pool(self) -> None:
549551
colored_warning("No pool is running.")
550552

551553
async def get_new_connection(self) -> Connection:
552-
"""Returns a new connection - doesn't retrieve it from the pool."""
554+
"""Return a new connection - doesn't retrieve it from the pool."""
553555
if self.pool:
554556
return await self.pool.connection()
555557

@@ -562,11 +564,21 @@ async def get_new_connection(self) -> Connection:
562564
)
563565
).connection()
564566

567+
def transform_response_to_dicts(
568+
self,
569+
results: list[dict[str, Any]] | dict[str, Any],
570+
) -> list[dict[str, Any]]:
571+
"""Transform result to list of dicts."""
572+
if isinstance(results, list):
573+
return results
574+
575+
return [results]
576+
565577
async def batch(
566578
self: Self,
567579
query: Query[Any, Any],
568580
batch_size: int = 100,
569-
node: Optional[str] = None,
581+
node: str | None = None,
570582
) -> AsyncBatch:
571583
"""Create new `AsyncBatch`.
572584
@@ -588,8 +600,8 @@ async def batch(
588600
async def _run_in_pool(
589601
self: Self,
590602
query: str,
591-
args: Optional[Sequence[Any]] = None,
592-
) -> List[Dict[str, Any]]:
603+
args: Sequence[Any] | None = None,
604+
) -> list[dict[str, Any]]:
593605
"""Run query in the pool.
594606
595607
### Parameters:
@@ -613,8 +625,8 @@ async def _run_in_pool(
613625
async def _run_in_new_connection(
614626
self: Self,
615627
query: str,
616-
args: Optional[Sequence[Any]] = None,
617-
) -> List[Dict[str, Any]]:
628+
args: Sequence[Any] | None = None,
629+
) -> list[dict[str, Any]]:
618630
"""Run query in a new connection.
619631
620632
### Parameters:
@@ -625,21 +637,19 @@ async def _run_in_new_connection(
625637
Result from the database as a list of dicts.
626638
"""
627639
connection = await self.get_new_connection()
628-
try:
629-
results = await connection.execute(
630-
querystring=query,
631-
parameters=args,
632-
)
633-
except RustPSQLDriverPyBaseError as exception:
634-
raise exception
640+
results = await connection.execute(
641+
querystring=query,
642+
parameters=args,
643+
)
644+
connection.back_to_pool()
635645

636646
return results.result()
637647

638648
async def run_querystring(
639649
self: Self,
640650
querystring: QueryString,
641651
in_pool: bool = True,
642-
) -> List[Dict[str, Any]]:
652+
) -> list[dict[str, Any]]:
643653
"""Run querystring.
644654
645655
### Parameters:
@@ -649,9 +659,6 @@ async def run_querystring(
649659
### Returns:
650660
Result from the database as a list of dicts.
651661
"""
652-
print("------------------")
653-
print("RUN", querystring)
654-
print("------------------")
655662
query, query_args = querystring.compile_string(engine_type=self.engine_type)
656663

657664
query_id = self.get_query_id()
@@ -674,14 +681,14 @@ async def run_querystring(
674681

675682
if self.log_responses:
676683
self.print_response(query_id=query_id, response=response)
677-
print(response)
684+
678685
return response
679686

680687
async def run_ddl(
681688
self: Self,
682689
ddl: str,
683690
in_pool: bool = True,
684-
) -> List[Dict[str, Any]]:
691+
) -> list[dict[str, Any]]:
685692
"""Run ddl query.
686693
687694
### Parameters:
@@ -697,7 +704,7 @@ async def run_ddl(
697704
current_transaction = self.current_transaction.get()
698705
if current_transaction:
699706
raw_response = await current_transaction.connection.fetch(ddl)
700-
raw_response.result()
707+
response = raw_response.result()
701708
elif in_pool and self.pool:
702709
response = await self._run_in_pool(ddl)
703710
else:

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ readme = "README.md"
77

88
[tool.poetry.dependencies]
99
python = "^3.8"
10-
psqlpy = "^0.6.0"
11-
piccolo = {version = "^1.13.1", extras = ["postgres"]}
10+
piccolo = {version = "^1.14.0", extras = ["postgres"]}
11+
psqlpy = "^0.7.4"
12+
typing-extensions = "^4.12.2"
1213

1314

1415
[tool.poetry.group.lint.dependencies]

tests/test_apps/music/piccolo_app.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@
22

33
from piccolo.conf.apps import AppConfig
44

5-
from tests.test_apps.mega.tables import MegaTable, SmallTable
6-
75
CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) # noqa: PTH120, PTH100
86

97

108
APP_CONFIG = AppConfig(
119
app_name="music",
12-
table_classes=[MegaTable, SmallTable],
10+
table_classes=[], # During models import some circular import happening for some reason.
1311
migrations_folder_path=os.path.join(CURRENT_DIRECTORY, "piccolo_migrations"), # noqa: PTH118
1412
commands=[],
1513
)

0 commit comments

Comments
 (0)