1
+ from __future__ import annotations
2
+
1
3
import contextvars
2
- import types
3
4
from dataclasses import dataclass
4
- from typing import Any , Dict , Generator , List , Mapping , Optional , Sequence , Type , Union
5
+ from typing import TYPE_CHECKING , Any , Generator , Mapping , Sequence
5
6
6
7
from piccolo .engine .base import BaseBatch , Engine , validate_savepoint_name
7
8
from piccolo .engine .exceptions import TransactionError
8
9
from piccolo .query .base import DDL , Query
9
- from piccolo .querystring import QueryString
10
10
from piccolo .utils .sync import run_sync
11
11
from piccolo .utils .warnings import Level , colored_warning
12
12
from psqlpy import Connection , ConnectionPool , Cursor , Transaction
13
13
from psqlpy .exceptions import RustPSQLDriverPyBaseError
14
14
from typing_extensions import Self
15
15
16
+ if TYPE_CHECKING :
17
+ import types
18
+
19
+ from piccolo .querystring import QueryString
20
+
16
21
17
22
@dataclass
18
23
class AsyncBatch (BaseBatch ):
@@ -23,8 +28,8 @@ class AsyncBatch(BaseBatch):
23
28
batch_size : int
24
29
25
30
# Set internally
26
- _transaction : Optional [ Transaction ] = None
27
- _cursor : Optional [ Cursor ] = None
31
+ _transaction : Transaction | None = None
32
+ _cursor : Cursor | None = None
28
33
29
34
@property
30
35
def cursor (self ) -> Cursor :
@@ -37,19 +42,19 @@ def cursor(self) -> Cursor:
37
42
raise ValueError ("_cursor not set" )
38
43
return self ._cursor
39
44
40
- async def next (self ) -> List [ Dict [str , Any ]]:
45
+ async def next (self ) -> list [ dict [str , Any ]]:
41
46
"""Retrieve next batch from the Cursor.
42
47
43
48
### Returns:
44
- List of dicts of results.
49
+ list of dicts of results.
45
50
"""
46
51
data = await self .cursor .fetch (self .batch_size )
47
52
return data .result ()
48
53
49
54
def __aiter__ (self : Self ) -> Self :
50
55
return self
51
56
52
- async def __anext__ (self : Self ) -> List [ Dict [str , Any ]]:
57
+ async def __anext__ (self : Self ) -> list [ dict [str , Any ]]:
53
58
response = await self .next ()
54
59
if response == []:
55
60
raise StopAsyncIteration
@@ -70,9 +75,9 @@ async def __aenter__(self: Self) -> Self:
70
75
71
76
async def __aexit__ (
72
77
self : Self ,
73
- exception_type : Optional [ Type [ BaseException ]] ,
74
- exception : Optional [ BaseException ] ,
75
- traceback : Optional [ types .TracebackType ] ,
78
+ exception_type : type [ BaseException ] | None ,
79
+ exception : BaseException | None ,
80
+ traceback : types .TracebackType | None ,
76
81
) -> bool :
77
82
if exception :
78
83
await self ._transaction .rollback () # type: ignore[union-attr]
@@ -98,19 +103,19 @@ class Atomic:
98
103
99
104
__slots__ = ("engine" , "queries" )
100
105
101
- def __init__ (self : Self , engine : " PSQLPyEngine" ) -> None :
106
+ def __init__ (self : Self , engine : PSQLPyEngine ) -> None :
102
107
"""Initialize programmatically configured atomic transaction.
103
108
104
109
### Parameters:
105
110
- `engine`: engine for query executing.
106
111
"""
107
112
self .engine = engine
108
- self .queries : List [ Union [ Query [Any , Any ], DDL ] ] = []
113
+ self .queries : list [ Query [Any , Any ] | DDL ] = []
109
114
110
115
def __await__ (self : Self ) -> Generator [Any , None , None ]:
111
116
return self .run ().__await__ ()
112
117
113
- def add (self : Self , * query : Union [ Query [Any , Any ], DDL ] ) -> None :
118
+ def add (self : Self , * query : Query [Any , Any ] | DDL ) -> None :
114
119
"""Add query to atomic transaction.
115
120
116
121
### Params:
@@ -128,7 +133,7 @@ async def run(self: Self) -> None:
128
133
if isinstance (query , (Query , DDL , Create , GetOrCreate )):
129
134
await query .run ()
130
135
else :
131
- raise ValueError ("Unrecognised query" )
136
+ raise TypeError ("Unrecognised query" ) # noqa: TRY301
132
137
self .queries = []
133
138
except Exception as exception :
134
139
self .queries = []
@@ -142,7 +147,7 @@ def run_sync(self: Self) -> None:
142
147
class Savepoint :
143
148
"""PostgreSQL `SAVEPOINT` representation in Python."""
144
149
145
- def __init__ (self : Self , name : str , transaction : " PostgresTransaction" ) -> None :
150
+ def __init__ (self : Self , name : str , transaction : PostgresTransaction ) -> None :
146
151
"""Initialize new `SAVEPOINT`.
147
152
148
153
### Parameters:
@@ -179,7 +184,7 @@ class PostgresTransaction:
179
184
180
185
"""
181
186
182
- def __init__ (self : Self , engine : " PSQLPyEngine" , allow_nested : bool = True ) -> None :
187
+ def __init__ (self : Self , engine : PSQLPyEngine , allow_nested : bool = True ) -> None :
183
188
"""Initialize new transaction.
184
189
185
190
### Parameters:
@@ -204,7 +209,7 @@ def __init__(self: Self, engine: "PSQLPyEngine", allow_nested: bool = True) -> N
204
209
"aren't allowed." ,
205
210
)
206
211
207
- async def __aenter__ (self : Self ) -> "PostgresTransaction" :
212
+ async def __aenter__ (self : Self ) -> Self :
208
213
if self ._parent is not None :
209
214
return self ._parent
210
215
@@ -218,9 +223,9 @@ async def __aenter__(self: Self) -> "PostgresTransaction":
218
223
219
224
async def __aexit__ (
220
225
self : Self ,
221
- exception_type : Optional [ Type [ BaseException ]] ,
222
- exception : Optional [ BaseException ] ,
223
- traceback : Optional [ types .TracebackType ] ,
226
+ exception_type : type [ BaseException ] | None ,
227
+ exception : BaseException | None ,
228
+ traceback : types .TracebackType | None ,
224
229
) -> bool :
225
230
if self ._parent :
226
231
return exception is None
@@ -271,7 +276,7 @@ def get_savepoint_id(self: Self) -> int:
271
276
self ._savepoint_id += 1
272
277
return self ._savepoint_id
273
278
274
- async def savepoint (self : Self , name : Optional [ str ] = None ) -> Savepoint :
279
+ async def savepoint (self : Self , name : str | None = None ) -> Savepoint :
275
280
"""Create new savepoint.
276
281
277
282
### Parameters:
@@ -351,11 +356,11 @@ class PSQLPyEngine(Engine[PostgresTransaction]):
351
356
352
357
def __init__ (
353
358
self : Self ,
354
- config : Dict [str , Any ],
359
+ config : dict [str , Any ],
355
360
extensions : Sequence [str ] = ("uuid-ossp" ,),
356
361
log_queries : bool = False ,
357
362
log_responses : bool = False ,
358
- extra_nodes : Optional [ Mapping [str , " PSQLPyEngine" ]] = None ,
363
+ extra_nodes : Mapping [str , PSQLPyEngine ] | None = None ,
359
364
) -> None :
360
365
"""Initialize `PSQLPyEngine`.
361
366
@@ -421,7 +426,7 @@ def __init__(
421
426
self .log_queries = log_queries
422
427
self .log_responses = log_responses
423
428
self .extra_nodes = extra_nodes
424
- self .pool : Optional [ ConnectionPool ] = None
429
+ self .pool : ConnectionPool | None = None
425
430
database_name = config .get ("database" , "Unknown" )
426
431
self .current_transaction = contextvars .ContextVar (
427
432
f"pg_current_transaction_{ database_name } " ,
@@ -449,7 +454,7 @@ def _parse_raw_version_string(version_string: str) -> float:
449
454
async def get_version (self : Self ) -> float :
450
455
"""Retrieve the version of Postgres being run."""
451
456
try :
452
- response : Sequence [Dict [str , Any ]] = await self ._run_in_new_connection (
457
+ response : Sequence [dict [str , Any ]] = await self ._run_in_new_connection (
453
458
"SHOW server_version" ,
454
459
)
455
460
except ConnectionRefusedError as exception :
@@ -475,7 +480,7 @@ async def prep_database(self: Self) -> None:
475
480
await self ._run_in_new_connection (
476
481
f'CREATE EXTENSION IF NOT EXISTS "{ extension } "' ,
477
482
)
478
- except RustPSQLDriverPyBaseError :
483
+ except RustPSQLDriverPyBaseError : # noqa: PERF203
479
484
colored_warning (
480
485
f"=> Unable to create { extension } extension - some "
481
486
"functionality may not behave as expected. Make sure "
@@ -487,7 +492,7 @@ async def prep_database(self: Self) -> None:
487
492
488
493
async def start_connnection_pool (
489
494
self : Self ,
490
- ** kwargs : Dict [str , Any ],
495
+ ** _kwargs : dict [str , Any ],
491
496
) -> None :
492
497
"""Start new connection pool.
493
498
@@ -504,7 +509,7 @@ async def start_connnection_pool(
504
509
)
505
510
return await self .start_connection_pool ()
506
511
507
- async def close_connnection_pool (self : Self , ** kwargs : Dict [str , Any ]) -> None :
512
+ async def close_connnection_pool (self : Self , ** _kwargs : dict [str , Any ]) -> None :
508
513
"""Close connection pool."""
509
514
colored_warning (
510
515
"`close_connnection_pool` is a typo - please change it to "
@@ -513,7 +518,7 @@ async def close_connnection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
513
518
)
514
519
return await self .close_connection_pool ()
515
520
516
- async def start_connection_pool (self : Self , ** kwargs : Dict [str , Any ]) -> None :
521
+ async def start_connection_pool (self : Self , ** kwargs : dict [str , Any ]) -> None :
517
522
"""Start new connection pool.
518
523
519
524
Create and start new connection pool.
@@ -530,9 +535,6 @@ async def start_connection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
530
535
else :
531
536
config = dict (self .config )
532
537
config .update (** kwargs )
533
- print ("----------------" )
534
- print (config )
535
- print ("----------------" )
536
538
self .pool = ConnectionPool (
537
539
db_name = config .pop ("database" , None ),
538
540
username = config .pop ("user" , None ),
@@ -549,7 +551,7 @@ async def close_connection_pool(self) -> None:
549
551
colored_warning ("No pool is running." )
550
552
551
553
async def get_new_connection (self ) -> Connection :
552
- """Returns a new connection - doesn't retrieve it from the pool."""
554
+ """Return a new connection - doesn't retrieve it from the pool."""
553
555
if self .pool :
554
556
return await self .pool .connection ()
555
557
@@ -562,11 +564,21 @@ async def get_new_connection(self) -> Connection:
562
564
)
563
565
).connection ()
564
566
567
+ def transform_response_to_dicts (
568
+ self ,
569
+ results : list [dict [str , Any ]] | dict [str , Any ],
570
+ ) -> list [dict [str , Any ]]:
571
+ """Transform result to list of dicts."""
572
+ if isinstance (results , list ):
573
+ return results
574
+
575
+ return [results ]
576
+
565
577
async def batch (
566
578
self : Self ,
567
579
query : Query [Any , Any ],
568
580
batch_size : int = 100 ,
569
- node : Optional [ str ] = None ,
581
+ node : str | None = None ,
570
582
) -> AsyncBatch :
571
583
"""Create new `AsyncBatch`.
572
584
@@ -588,8 +600,8 @@ async def batch(
588
600
async def _run_in_pool (
589
601
self : Self ,
590
602
query : str ,
591
- args : Optional [ Sequence [Any ]] = None ,
592
- ) -> List [ Dict [str , Any ]]:
603
+ args : Sequence [Any ] | None = None ,
604
+ ) -> list [ dict [str , Any ]]:
593
605
"""Run query in the pool.
594
606
595
607
### Parameters:
@@ -613,8 +625,8 @@ async def _run_in_pool(
613
625
async def _run_in_new_connection (
614
626
self : Self ,
615
627
query : str ,
616
- args : Optional [ Sequence [Any ]] = None ,
617
- ) -> List [ Dict [str , Any ]]:
628
+ args : Sequence [Any ] | None = None ,
629
+ ) -> list [ dict [str , Any ]]:
618
630
"""Run query in a new connection.
619
631
620
632
### Parameters:
@@ -625,21 +637,19 @@ async def _run_in_new_connection(
625
637
Result from the database as a list of dicts.
626
638
"""
627
639
connection = await self .get_new_connection ()
628
- try :
629
- results = await connection .execute (
630
- querystring = query ,
631
- parameters = args ,
632
- )
633
- except RustPSQLDriverPyBaseError as exception :
634
- raise exception
640
+ results = await connection .execute (
641
+ querystring = query ,
642
+ parameters = args ,
643
+ )
644
+ connection .back_to_pool ()
635
645
636
646
return results .result ()
637
647
638
648
async def run_querystring (
639
649
self : Self ,
640
650
querystring : QueryString ,
641
651
in_pool : bool = True ,
642
- ) -> List [ Dict [str , Any ]]:
652
+ ) -> list [ dict [str , Any ]]:
643
653
"""Run querystring.
644
654
645
655
### Parameters:
@@ -649,9 +659,6 @@ async def run_querystring(
649
659
### Returns:
650
660
Result from the database as a list of dicts.
651
661
"""
652
- print ("------------------" )
653
- print ("RUN" , querystring )
654
- print ("------------------" )
655
662
query , query_args = querystring .compile_string (engine_type = self .engine_type )
656
663
657
664
query_id = self .get_query_id ()
@@ -674,14 +681,14 @@ async def run_querystring(
674
681
675
682
if self .log_responses :
676
683
self .print_response (query_id = query_id , response = response )
677
- print ( response )
684
+
678
685
return response
679
686
680
687
async def run_ddl (
681
688
self : Self ,
682
689
ddl : str ,
683
690
in_pool : bool = True ,
684
- ) -> List [ Dict [str , Any ]]:
691
+ ) -> list [ dict [str , Any ]]:
685
692
"""Run ddl query.
686
693
687
694
### Parameters:
@@ -697,7 +704,7 @@ async def run_ddl(
697
704
current_transaction = self .current_transaction .get ()
698
705
if current_transaction :
699
706
raw_response = await current_transaction .connection .fetch (ddl )
700
- raw_response .result ()
707
+ response = raw_response .result ()
701
708
elif in_pool and self .pool :
702
709
response = await self ._run_in_pool (ddl )
703
710
else :
0 commit comments