Skip to content

Commit 64833b9

Browse files
committed
feat: implement query cancellation
1 parent a2faa95 commit 64833b9

File tree

4 files changed

+35
-15
lines changed

4 files changed

+35
-15
lines changed

wherobots/db/connection.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def __listen(self):
131131
# Incoming state transitions are handled here.
132132
if query.state == ExecutionState.SUCCEEDED:
133133
self.__request_results(execution_id)
134+
elif query.state == ExecutionState.CANCELLED:
135+
logging.info("Query %s has been cancelled.", execution_id)
136+
self.__queries.pop(execution_id)
134137
elif query.state == ExecutionState.FAILED:
135138
# Don't do anything here; the ERROR event is coming with more
136139
# details.
@@ -230,7 +233,14 @@ def __request_results(self, execution_id: str) -> None:
230233
self.__send(request)
231234

232235
def __cancel_query(self, execution_id: str) -> None:
233-
query = self.__queries.pop(execution_id)
234-
if query:
235-
logging.info("Cancelled query %s.", execution_id)
236-
# TODO: when protocol supports it, send cancellation request.
236+
"""Cancels the query with the given execution ID."""
237+
query = self.__queries.get(execution_id)
238+
if not query:
239+
return
240+
241+
request = {
242+
"kind": RequestKind.CANCEL.value,
243+
"execution_id": execution_id,
244+
}
245+
logging.info("Cancelling query %s...", execution_id)
246+
self.__send(request)

wherobots/db/constants.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import auto
2+
from packaging.version import Version
23
from strenum import LowercaseStrEnum, StrEnum
34

45
from .region import Region
@@ -15,7 +16,7 @@
1516
DEFAULT_REUSE_SESSION: bool = True
1617

1718
MAX_MESSAGE_SIZE: int = 100 * 2**20 # 100MiB
18-
PROTOCOL_VERSION: str = "1.0.0"
19+
PROTOCOL_VERSION: Version = Version("1.0.0")
1920

2021

2122
class ExecutionState(LowercaseStrEnum):
@@ -31,6 +32,9 @@ class ExecutionState(LowercaseStrEnum):
3132
SUCCEEDED = auto()
3233
"The SQL session has reported the query has completed successfully."
3334

35+
CANCELLED = auto()
36+
"The SQL session has reported the query has been cancelled."
37+
3438
FAILED = auto()
3539
"The SQL session has reported the query has failed."
3640

@@ -41,12 +45,17 @@ class ExecutionState(LowercaseStrEnum):
4145
"The driver has completed processing the query results."
4246

4347
def is_terminal_state(self):
44-
return self in (ExecutionState.COMPLETED, ExecutionState.FAILED)
48+
return self in (
49+
ExecutionState.COMPLETED,
50+
ExecutionState.CANCELLED,
51+
ExecutionState.FAILED,
52+
)
4553

4654

4755
class RequestKind(LowercaseStrEnum):
4856
EXECUTE_SQL = auto()
4957
RETRIEVE_RESULTS = auto()
58+
CANCEL = auto()
5059

5160

5261
class EventKind(LowercaseStrEnum):

wherobots/db/cursor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]):
8989

9090
def fetchone(self):
9191
results = self.__get_results()[self.__current_row :]
92-
if not results:
92+
if len(results) == 0:
9393
return None
9494
self.__current_row += 1
9595
return results[0]
@@ -105,7 +105,8 @@ def fetchall(self):
105105

106106
def close(self):
107107
"""Close the cursor."""
108-
pass
108+
if self.__results is None and self.__current_execution_id:
109+
self.__cancel_fn(self.__current_execution_id)
109110

110111
def __iter__(self):
111112
return self

wherobots/db/driver.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
A PEP-0249 compatible driver for interfacing with Wherobots DB.
44
"""
55

6+
from importlib import metadata
7+
from importlib.metadata import PackageNotFoundError
68
import logging
9+
from packaging.version import Version
710
import platform
8-
import urllib.parse
911
import queue
10-
from importlib import metadata
11-
from importlib.metadata import PackageNotFoundError
12-
1312
import requests
1413
import tenacity
15-
import threading
1614
from typing import Union
15+
import urllib.parse
1716
import websockets.sync.client
1817

18+
from .connection import Connection
1919
from .constants import (
2020
DEFAULT_ENDPOINT,
2121
DEFAULT_REGION,
@@ -36,7 +36,6 @@
3636
)
3737
from .region import Region
3838
from .runtime import Runtime
39-
from .connection import Connection
4039

4140
apilevel = "2.0"
4241
threadsafety = 1
@@ -163,14 +162,15 @@ def http_to_ws(uri: str) -> str:
163162

164163
def connect_direct(
165164
uri: str,
165+
protocol: Version = PROTOCOL_VERSION,
166166
headers: dict[str, str] = None,
167167
read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS,
168168
results_format: Union[ResultsFormat, None] = None,
169169
data_compression: Union[DataCompression, None] = None,
170170
geometry_representation: Union[GeometryRepresentation, None] = None,
171171
) -> Connection:
172172
q = queue.SimpleQueue()
173-
uri_with_protocol = f"{uri}/{PROTOCOL_VERSION}"
173+
uri_with_protocol = f"{uri}/{protocol}"
174174

175175
try:
176176
logging.info("Connecting to SQL session at %s ...", uri_with_protocol)

0 commit comments

Comments
 (0)