Skip to content

Commit 06c0e87

Browse files
authored
Merge pull request #23 from wherobots/max/query-cancel
Fix support for query cancellation
2 parents b4dc1b7 + 7d59bfe commit 06c0e87

File tree

7 files changed

+95
-58
lines changed

7 files changed

+95
-58
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,6 @@ cython_debug/
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
.idea/
161+
162+
# Vim
163+
*.swp

poetry.lock

Lines changed: 15 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "wherobots-python-dbapi"
3-
version = "0.9.0"
3+
version = "0.9.1"
44
description = "Python DB-API driver for Wherobots DB"
55
authors = ["Maxime Petazzoni <[email protected]>"]
66
license = "Apache 2.0"
@@ -29,6 +29,7 @@ pytest = "^8.0.2"
2929
black = "^24.2.0"
3030
pre-commit = "^3.6.2"
3131
conventional-pre-commit = "^3.1.0"
32+
types-requests = "^2.32.0.20241016"
3233
rich = "^13.7.1"
3334

3435
[build-system]

tests/smoke.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# A simple smoke test for the DB driver.
22

33
import argparse
4+
import concurrent.futures
45
import functools
56
import logging
67
import sys
@@ -11,6 +12,7 @@
1112

1213
from wherobots.db import connect, connect_direct
1314
from wherobots.db.constants import DEFAULT_ENDPOINT
15+
from wherobots.db.connection import Connection
1416
from wherobots.db.region import Region
1517
from wherobots.db.runtime import Runtime
1618

@@ -85,8 +87,13 @@ def render(results: pandas.DataFrame):
8587
table.add_row(*r)
8688
Console().print(table)
8789

90+
def execute(conn: Connection, sql: str) -> pandas.DataFrame:
91+
with conn.cursor() as cursor:
92+
cursor.execute(sql)
93+
return cursor.fetchall()
94+
8895
with conn_func() as conn:
89-
for sql in args.sql:
90-
with conn.cursor() as cursor:
91-
cursor.execute(sql)
92-
render(cursor.fetchall())
96+
with concurrent.futures.ThreadPoolExecutor() as pool:
97+
futures = [pool.submit(execute, conn, s) for s in args.sql]
98+
for future in concurrent.futures.as_completed(futures):
99+
render(future.result())

wherobots/db/connection.py

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Callable, Union
88

99
import cbor2
10+
import pandas
1011
import pyarrow
1112
import websockets.exceptions
1213
import websockets.protocol
@@ -74,19 +75,19 @@ def __enter__(self):
7475
def __exit__(self, exc_type, exc_val, exc_tb):
7576
self.close()
7677

77-
def close(self):
78+
def close(self) -> None:
7879
self.__ws.close()
7980

80-
def commit(self):
81+
def commit(self) -> None:
8182
raise NotSupportedError
8283

83-
def rollback(self):
84+
def rollback(self) -> None:
8485
raise NotSupportedError
8586

8687
def cursor(self) -> Cursor:
8788
return Cursor(self.__execute_sql, self.__cancel_query)
8889

89-
def __main_loop(self):
90+
def __main_loop(self) -> None:
9091
"""Main background loop listening for messages from the SQL session."""
9192
logging.info("Starting background connection handling loop...")
9293
while self.__ws.protocol.state < websockets.protocol.State.CLOSING:
@@ -101,7 +102,7 @@ def __main_loop(self):
101102
except Exception as e:
102103
logging.exception("Error handling message from SQL session", exc_info=e)
103104

104-
def __listen(self):
105+
def __listen(self) -> None:
105106
"""Waits for the next message from the SQL session and processes it.
106107
107108
The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
@@ -120,61 +121,70 @@ def __listen(self):
120121
)
121122
return
122123

123-
if kind == EventKind.STATE_UPDATED:
124+
# Incoming state transitions are handled here.
125+
if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT:
124126
try:
125127
query.state = ExecutionState[message["state"].upper()]
126128
logging.info("Query %s is now %s.", execution_id, query.state)
127129
except KeyError:
128130
logging.warning("Invalid state update message for %s", execution_id)
129131
return
130132

131-
# Incoming state transitions are handled here.
132133
if query.state == ExecutionState.SUCCEEDED:
133-
self.__request_results(execution_id)
134+
# On a state_updated event telling us the query succeeded,
135+
# ask for results.
136+
if kind == EventKind.STATE_UPDATED:
137+
self.__request_results(execution_id)
138+
return
139+
140+
# Otherwise, process the results from the execution_result event.
141+
results = message.get("results")
142+
if not results or not isinstance(results, dict):
143+
logging.warning("Got no results back from %s.", execution_id)
144+
return
145+
146+
query.state = ExecutionState.COMPLETED
147+
query.handler(self._handle_results(execution_id, results))
134148
elif query.state == ExecutionState.CANCELLED:
135-
logging.info("Query %s has been cancelled.", execution_id)
149+
logging.info(
150+
"Query %s has been cancelled; returning empty results.",
151+
execution_id,
152+
)
153+
query.handler(pandas.DataFrame())
136154
self.__queries.pop(execution_id)
137155
elif query.state == ExecutionState.FAILED:
138156
# Don't do anything here; the ERROR event is coming with more
139157
# details.
140158
pass
141-
142-
elif kind == EventKind.EXECUTION_RESULT:
143-
results = message.get("results")
144-
if not results or not isinstance(results, dict):
145-
logging.warning("Got no results back from %s.", execution_id)
146-
return
147-
148-
result_bytes = results.get("result_bytes")
149-
result_format = results.get("format")
150-
result_compression = results.get("compression")
151-
logging.info(
152-
"Received %d bytes of %s-compressed %s results from %s.",
153-
len(result_bytes),
154-
result_compression,
155-
result_format,
156-
execution_id,
157-
)
158-
159-
query.state = ExecutionState.COMPLETED
160-
if result_format == ResultsFormat.JSON:
161-
query.handler(json.loads(result_bytes.decode("utf-8")))
162-
elif result_format == ResultsFormat.ARROW:
163-
buffer = pyarrow.py_buffer(result_bytes)
164-
stream = pyarrow.input_stream(buffer, result_compression)
165-
with pyarrow.ipc.open_stream(stream) as reader:
166-
query.handler(reader.read_pandas())
167-
else:
168-
query.handler(
169-
OperationalError(f"Unsupported results format {result_format}")
170-
)
171159
elif kind == EventKind.ERROR:
172160
query.state = ExecutionState.FAILED
173161
error = message.get("message")
174162
query.handler(OperationalError(error))
175163
else:
176164
logging.warning("Received unknown %s event!", kind)
177165

166+
def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any:
167+
result_bytes = results.get("result_bytes")
168+
result_format = results.get("format")
169+
result_compression = results.get("compression")
170+
logging.info(
171+
"Received %d bytes of %s-compressed %s results from %s.",
172+
len(result_bytes),
173+
result_compression,
174+
result_format,
175+
execution_id,
176+
)
177+
178+
if result_format == ResultsFormat.JSON:
179+
return json.loads(result_bytes.decode("utf-8"))
180+
elif result_format == ResultsFormat.ARROW:
181+
buffer = pyarrow.py_buffer(result_bytes)
182+
stream = pyarrow.input_stream(buffer, result_compression)
183+
with pyarrow.ipc.open_stream(stream) as reader:
184+
return reader.read_pandas()
185+
else:
186+
return OperationalError(f"Unsupported results format {result_format}")
187+
178188
def __send(self, message: dict[str, Any]) -> None:
179189
request = json.dumps(message)
180190
logging.debug("Request: %s", request)

wherobots/db/constants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class ExecutionState(LowercaseStrEnum):
4444
COMPLETED = auto()
4545
"The driver has completed processing the query results."
4646

47-
def is_terminal_state(self):
47+
def is_terminal_state(self) -> bool:
4848
return self in (
4949
ExecutionState.COMPLETED,
5050
ExecutionState.CANCELLED,
@@ -97,7 +97,7 @@ class AppStatus(StrEnum):
9797
DESTROY_FAILED = auto()
9898
DESTROYED = auto()
9999

100-
def is_starting(self):
100+
def is_starting(self) -> bool:
101101
return self in (
102102
AppStatus.PENDING,
103103
AppStatus.PREPARING,
@@ -107,7 +107,7 @@ def is_starting(self):
107107
AppStatus.INITIALIZING,
108108
)
109109

110-
def is_terminal_state(self):
110+
def is_terminal_state(self) -> bool:
111111
return self in (
112112
AppStatus.PREPARE_FAILED,
113113
AppStatus.DEPLOY_FAILED,

wherobots/db/cursor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import queue
22
from typing import Any, Optional, List, Tuple
33

4-
from .errors import ProgrammingError, DatabaseError
4+
from .errors import DatabaseError, ProgrammingError
55

66
_TYPE_MAP = {
77
"object": "STRING",
@@ -16,7 +16,7 @@
1616

1717
class Cursor:
1818

19-
def __init__(self, exec_fn, cancel_fn):
19+
def __init__(self, exec_fn, cancel_fn) -> None:
2020
self.__exec_fn = exec_fn
2121
self.__cancel_fn = cancel_fn
2222

@@ -72,7 +72,7 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:
7272

7373
return self.__results
7474

75-
def execute(self, operation: str, parameters: dict[str, Any] = None):
75+
def execute(self, operation: str, parameters: dict[str, Any] = None) -> None:
7676
if self.__current_execution_id:
7777
self.__cancel_fn(self.__current_execution_id)
7878

@@ -84,38 +84,40 @@ def execute(self, operation: str, parameters: dict[str, Any] = None):
8484
sql = operation.format(**(parameters or {}))
8585
self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result)
8686

87-
def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]):
87+
def executemany(
88+
self, operation: str, seq_of_parameters: list[dict[str, Any]]
89+
) -> None:
8890
raise NotImplementedError
8991

90-
def fetchone(self):
92+
def fetchone(self) -> Any:
9193
results = self.__get_results()[self.__current_row :]
9294
if len(results) == 0:
9395
return None
9496
self.__current_row += 1
9597
return results[0]
9698

97-
def fetchmany(self, size: int = None):
99+
def fetchmany(self, size: int = None) -> list[Any]:
98100
size = size or self.arraysize
99101
results = self.__get_results()[self.__current_row : self.__current_row + size]
100102
self.__current_row += size
101103
return results
102104

103-
def fetchall(self):
105+
def fetchall(self) -> list[Any]:
104106
return self.__get_results()[self.__current_row :]
105107

106-
def close(self):
108+
def close(self) -> None:
107109
"""Close the cursor."""
108110
if self.__results is None and self.__current_execution_id:
109111
self.__cancel_fn(self.__current_execution_id)
110112

111113
def __iter__(self):
112114
return self
113115

114-
def __next__(self):
116+
def __next__(self) -> None:
115117
raise StopIteration
116118

117119
def __enter__(self):
118120
return self
119121

120-
def __exit__(self, exc_type, exc_val, exc_tb):
122+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
121123
self.close()

0 commit comments

Comments
 (0)