|
7 | 7 | from typing import Any, Callable, Union
|
8 | 8 |
|
9 | 9 | import cbor2
|
| 10 | +import pandas |
10 | 11 | import pyarrow
|
11 | 12 | import websockets.exceptions
|
12 | 13 | import websockets.protocol
|
@@ -120,61 +121,70 @@ def __listen(self) -> None:
|
120 | 121 | )
|
121 | 122 | return
|
122 | 123 |
|
123 |
| - if kind == EventKind.STATE_UPDATED: |
| 124 | + # Incoming state transitions are handled here. |
| 125 | + if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT: |
124 | 126 | try:
|
125 | 127 | query.state = ExecutionState[message["state"].upper()]
|
126 | 128 | logging.info("Query %s is now %s.", execution_id, query.state)
|
127 | 129 | except KeyError:
|
128 | 130 | logging.warning("Invalid state update message for %s", execution_id)
|
129 | 131 | return
|
130 | 132 |
|
131 |
| - # Incoming state transitions are handled here. |
132 | 133 | if query.state == ExecutionState.SUCCEEDED:
|
133 |
| - self.__request_results(execution_id) |
| 134 | + # On a state_updated event telling us the query succeeded, |
| 135 | + # ask for results. |
| 136 | + if kind == EventKind.STATE_UPDATED: |
| 137 | + self.__request_results(execution_id) |
| 138 | + return |
| 139 | + |
| 140 | + # Otherwise, process the results from the execution_result event. |
| 141 | + results = message.get("results") |
| 142 | + if not results or not isinstance(results, dict): |
| 143 | + logging.warning("Got no results back from %s.", execution_id) |
| 144 | + return |
| 145 | + |
| 146 | + query.state = ExecutionState.COMPLETED |
| 147 | + query.handler(self._handle_results(execution_id, results)) |
134 | 148 | elif query.state == ExecutionState.CANCELLED:
|
135 |
| - logging.info("Query %s has been cancelled.", execution_id) |
| 149 | + logging.info( |
| 150 | + "Query %s has been cancelled; returning empty results.", |
| 151 | + execution_id, |
| 152 | + ) |
| 153 | + query.handler(pandas.DataFrame()) |
136 | 154 | self.__queries.pop(execution_id)
|
137 | 155 | elif query.state == ExecutionState.FAILED:
|
138 | 156 | # Don't do anything here; the ERROR event is coming with more
|
139 | 157 | # details.
|
140 | 158 | pass
|
141 |
| - |
142 |
| - elif kind == EventKind.EXECUTION_RESULT: |
143 |
| - results = message.get("results") |
144 |
| - if not results or not isinstance(results, dict): |
145 |
| - logging.warning("Got no results back from %s.", execution_id) |
146 |
| - return |
147 |
| - |
148 |
| - result_bytes = results.get("result_bytes") |
149 |
| - result_format = results.get("format") |
150 |
| - result_compression = results.get("compression") |
151 |
| - logging.info( |
152 |
| - "Received %d bytes of %s-compressed %s results from %s.", |
153 |
| - len(result_bytes), |
154 |
| - result_compression, |
155 |
| - result_format, |
156 |
| - execution_id, |
157 |
| - ) |
158 |
| - |
159 |
| - query.state = ExecutionState.COMPLETED |
160 |
| - if result_format == ResultsFormat.JSON: |
161 |
| - query.handler(json.loads(result_bytes.decode("utf-8"))) |
162 |
| - elif result_format == ResultsFormat.ARROW: |
163 |
| - buffer = pyarrow.py_buffer(result_bytes) |
164 |
| - stream = pyarrow.input_stream(buffer, result_compression) |
165 |
| - with pyarrow.ipc.open_stream(stream) as reader: |
166 |
| - query.handler(reader.read_pandas()) |
167 |
| - else: |
168 |
| - query.handler( |
169 |
| - OperationalError(f"Unsupported results format {result_format}") |
170 |
| - ) |
171 | 159 | elif kind == EventKind.ERROR:
|
172 | 160 | query.state = ExecutionState.FAILED
|
173 | 161 | error = message.get("message")
|
174 | 162 | query.handler(OperationalError(error))
|
175 | 163 | else:
|
176 | 164 | logging.warning("Received unknown %s event!", kind)
|
177 | 165 |
|
| 166 | + def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any: |
| 167 | + result_bytes = results.get("result_bytes") |
| 168 | + result_format = results.get("format") |
| 169 | + result_compression = results.get("compression") |
| 170 | + logging.info( |
| 171 | + "Received %d bytes of %s-compressed %s results from %s.", |
| 172 | + len(result_bytes), |
| 173 | + result_compression, |
| 174 | + result_format, |
| 175 | + execution_id, |
| 176 | + ) |
| 177 | + |
| 178 | + if result_format == ResultsFormat.JSON: |
| 179 | + return json.loads(result_bytes.decode("utf-8")) |
| 180 | + elif result_format == ResultsFormat.ARROW: |
| 181 | + buffer = pyarrow.py_buffer(result_bytes) |
| 182 | + stream = pyarrow.input_stream(buffer, result_compression) |
| 183 | + with pyarrow.ipc.open_stream(stream) as reader: |
| 184 | + return reader.read_pandas() |
| 185 | + else: |
| 186 | + return OperationalError(f"Unsupported results format {result_format}") |
| 187 | + |
178 | 188 | def __send(self, message: dict[str, Any]) -> None:
|
179 | 189 | request = json.dumps(message)
|
180 | 190 | logging.debug("Request: %s", request)
|
|
0 commit comments