|
33 | 33 | from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
|
34 | 34 | from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
|
35 | 35 | from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
|
| 36 | +from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS |
36 | 37 | from airflow.triggers.base import BaseTrigger, TriggerEvent
|
37 | 38 | from airflow.utils.session import provide_session
|
38 | 39 | from airflow.utils.state import TaskInstanceState
|
@@ -141,16 +142,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
|
141 | 142 | )
|
142 | 143 | return task_instance
|
143 | 144 |
|
def get_task_state(self):
    """
    Return the current state of the task instance that owns this trigger.

    Requests the state of exactly one dag/task/run combination through
    ``RuntimeTaskInstance.get_task_states`` and extracts the matching entry from the
    response, which is indexed first by run id and then by task id.

    :return: The state found in the response for this task instance.
    :raises AirflowException: If the response holds no entry for this task instance.
    """
    # Imported locally rather than at module level — presumably because the Task SDK
    # is only importable on Airflow 3+; confirm before hoisting.
    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    ti = self.task_instance
    task_states_response = RuntimeTaskInstance.get_task_states(
        dag_id=ti.dag_id,
        task_ids=[ti.task_id],
        run_ids=[ti.run_id],
    )
    try:
        return task_states_response[ti.run_id][ti.task_id]
    except (LookupError, TypeError) as err:
        # Exceptions do not interpolate logging-style "%s" arguments, so the message
        # is formatted eagerly; the original lookup failure is kept as the cause.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, "
            f"task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} is not found"
        ) from err
| 163 | + |
def safe_to_cancel(self) -> bool:
    """
    Tell whether the external job run by this trigger may be cancelled.

    ``asyncio.CancelledError`` can also be raised simply because the trigger itself is
    being stopped; in that situation the external job must NOT be cancelled. The check
    therefore reports ``True`` only when the task is no longer in the deferred state.
    """
    if not AIRFLOW_V_3_0_PLUS:
        # Database query is needed to get the latest state of the task instance.
        latest_ti = self.get_task_instance()  # type: ignore[call-arg]
        return latest_ti.state != TaskInstanceState.DEFERRED

    # On Airflow 3+ the state is obtained through get_task_state() instead.
    return self.get_task_state() != TaskInstanceState.DEFERRED
154 | 178 |
|
155 | 179 | async def run(self):
|
156 | 180 | try:
|
@@ -243,16 +267,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
|
243 | 267 | )
|
244 | 268 | return task_instance
|
245 | 269 |
|
def get_task_state(self):
    """
    Return the current state of the task instance that owns this trigger.

    Requests the state of exactly one dag/task/run combination through
    ``RuntimeTaskInstance.get_task_states`` and extracts the matching entry from the
    response, which is indexed first by run id and then by task id.

    :return: The state found in the response for this task instance.
    :raises AirflowException: If the response holds no entry for this task instance.
    """
    # Imported locally rather than at module level — presumably because the Task SDK
    # is only importable on Airflow 3+; confirm before hoisting.
    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    ti = self.task_instance
    task_states_response = RuntimeTaskInstance.get_task_states(
        dag_id=ti.dag_id,
        task_ids=[ti.task_id],
        run_ids=[ti.run_id],
    )
    try:
        return task_states_response[ti.run_id][ti.task_id]
    except (LookupError, TypeError) as err:
        # Exceptions do not interpolate logging-style "%s" arguments, so the message
        # is formatted eagerly; the original lookup failure is kept as the cause.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, "
            f"task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} is not found"
        ) from err
| 288 | + |
def safe_to_cancel(self) -> bool:
    """
    Tell whether the external job run by this trigger may be cancelled.

    ``asyncio.CancelledError`` can also be raised simply because the trigger itself is
    being stopped; in that situation the external job must NOT be cancelled. The check
    therefore reports ``True`` only when the task is no longer in the deferred state.
    """
    if not AIRFLOW_V_3_0_PLUS:
        # Database query is needed to get the latest state of the task instance.
        latest_ti = self.get_task_instance()  # type: ignore[call-arg]
        return latest_ti.state != TaskInstanceState.DEFERRED

    # On Airflow 3+ the state is obtained through get_task_state() instead.
    return self.get_task_state() != TaskInstanceState.DEFERRED
256 | 303 |
|
257 | 304 | async def run(self) -> AsyncIterator[TriggerEvent]:
|
258 | 305 | try:
|
|
0 commit comments