Skip to content

Commit 9adcf84

Browse files
committed
feat: refactor state transitions to fix support for query cancellation
1 parent d7c8ec5 commit 9adcf84

File tree

2 files changed

+47
-37
lines changed

2 files changed

+47
-37
lines changed

wherobots/db/connection.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Callable, Union
88

99
import cbor2
10+
import pandas
1011
import pyarrow
1112
import websockets.exceptions
1213
import websockets.protocol
@@ -120,61 +121,70 @@ def __listen(self) -> None:
120121
)
121122
return
122123

123-
if kind == EventKind.STATE_UPDATED:
124+
# Incoming state transitions are handled here.
125+
if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT:
124126
try:
125127
query.state = ExecutionState[message["state"].upper()]
126128
logging.info("Query %s is now %s.", execution_id, query.state)
127129
except KeyError:
128130
logging.warning("Invalid state update message for %s", execution_id)
129131
return
130132

131-
# Incoming state transitions are handled here.
132133
if query.state == ExecutionState.SUCCEEDED:
133-
self.__request_results(execution_id)
134+
# On a state_updated event telling us the query succeeded,
135+
# ask for results.
136+
if kind == EventKind.STATE_UPDATED:
137+
self.__request_results(execution_id)
138+
return
139+
140+
# Otherwise, process the results from the execution_result event.
141+
results = message.get("results")
142+
if not results or not isinstance(results, dict):
143+
logging.warning("Got no results back from %s.", execution_id)
144+
return
145+
146+
query.state = ExecutionState.COMPLETED
147+
query.handler(self._handle_results(execution_id, results))
134148
elif query.state == ExecutionState.CANCELLED:
135-
logging.info("Query %s has been cancelled.", execution_id)
149+
logging.info(
150+
"Query %s has been cancelled; returning empty results.",
151+
execution_id,
152+
)
153+
query.handler(pandas.DataFrame())
136154
self.__queries.pop(execution_id)
137155
elif query.state == ExecutionState.FAILED:
138156
# Don't do anything here; the ERROR event is coming with more
139157
# details.
140158
pass
141-
142-
elif kind == EventKind.EXECUTION_RESULT:
143-
results = message.get("results")
144-
if not results or not isinstance(results, dict):
145-
logging.warning("Got no results back from %s.", execution_id)
146-
return
147-
148-
result_bytes = results.get("result_bytes")
149-
result_format = results.get("format")
150-
result_compression = results.get("compression")
151-
logging.info(
152-
"Received %d bytes of %s-compressed %s results from %s.",
153-
len(result_bytes),
154-
result_compression,
155-
result_format,
156-
execution_id,
157-
)
158-
159-
query.state = ExecutionState.COMPLETED
160-
if result_format == ResultsFormat.JSON:
161-
query.handler(json.loads(result_bytes.decode("utf-8")))
162-
elif result_format == ResultsFormat.ARROW:
163-
buffer = pyarrow.py_buffer(result_bytes)
164-
stream = pyarrow.input_stream(buffer, result_compression)
165-
with pyarrow.ipc.open_stream(stream) as reader:
166-
query.handler(reader.read_pandas())
167-
else:
168-
query.handler(
169-
OperationalError(f"Unsupported results format {result_format}")
170-
)
171159
elif kind == EventKind.ERROR:
172160
query.state = ExecutionState.FAILED
173161
error = message.get("message")
174162
query.handler(OperationalError(error))
175163
else:
176164
logging.warning("Received unknown %s event!", kind)
177165

166+
def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any:
167+
result_bytes = results.get("result_bytes")
168+
result_format = results.get("format")
169+
result_compression = results.get("compression")
170+
logging.info(
171+
"Received %d bytes of %s-compressed %s results from %s.",
172+
len(result_bytes),
173+
result_compression,
174+
result_format,
175+
execution_id,
176+
)
177+
178+
if result_format == ResultsFormat.JSON:
179+
return json.loads(result_bytes.decode("utf-8"))
180+
elif result_format == ResultsFormat.ARROW:
181+
buffer = pyarrow.py_buffer(result_bytes)
182+
stream = pyarrow.input_stream(buffer, result_compression)
183+
with pyarrow.ipc.open_stream(stream) as reader:
184+
return reader.read_pandas()
185+
else:
186+
return OperationalError(f"Unsupported results format {result_format}")
187+
178188
def __send(self, message: dict[str, Any]) -> None:
179189
request = json.dumps(message)
180190
logging.debug("Request: %s", request)

wherobots/db/cursor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import queue
22
from typing import Any, Optional, List, Tuple
33

4-
from .errors import ProgrammingError, DatabaseError
4+
from .errors import DatabaseError, ProgrammingError
55

66
_TYPE_MAP = {
77
"object": "STRING",
@@ -110,13 +110,13 @@ def close(self) -> None:
110110
if self.__results is None and self.__current_execution_id:
111111
self.__cancel_fn(self.__current_execution_id)
112112

113-
def __iter__(self) -> Cursor:
113+
def __iter__(self):
114114
return self
115115

116116
def __next__(self) -> None:
117117
raise StopIteration
118118

119-
def __enter__(self) -> Cursor:
119+
def __enter__(self):
120120
return self
121121

122122
def __exit__(self, exc_type, exc_val, exc_tb) -> None:

0 commit comments

Comments
 (0)