Skip to content

Commit f5823bc

Browse files
committed
Remove direct access to DB for safe_to_cancel() method for Dataproc and BigQuery triggers
1 parent 1ec1e68 commit f5823bc

File tree

2 files changed

+80
-9
lines changed

2 files changed

+80
-9
lines changed

providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from airflow.exceptions import AirflowException
2727
from airflow.models.taskinstance import TaskInstance
2828
from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook
29+
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
2930
from airflow.triggers.base import BaseTrigger, TriggerEvent
3031
from airflow.utils.session import provide_session
3132
from airflow.utils.state import TaskInstanceState
@@ -116,16 +117,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
116117
)
117118
return task_instance
118119

120+
def get_task_state(self):
121+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
122+
123+
task_states_response = RuntimeTaskInstance.get_task_states(
124+
dag_id=self.task_instance.dag_id,
125+
task_ids=[self.task_instance.task_id],
126+
run_ids=[self.task_instance.run_id],
127+
)
128+
try:
129+
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
130+
except Exception:
131+
raise AirflowException(
132+
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s is not found",
133+
self.task_instance.dag_id,
134+
self.task_instance.task_id,
135+
self.task_instance.run_id,
136+
)
137+
return task_state
138+
119139
def safe_to_cancel(self) -> bool:
120140
"""
121141
Whether it is safe to cancel the external job which is being executed by this trigger.
122142
123143
This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
124144
Because in those cases, we should NOT cancel the external job.
125145
"""
126-
# Database query is needed to get the latest state of the task instance.
127-
task_instance = self.get_task_instance() # type: ignore[call-arg]
128-
return task_instance.state != TaskInstanceState.DEFERRED
146+
if AIRFLOW_V_3_0_PLUS:
147+
task_state = self.get_task_state()
148+
else:
149+
# Database query is needed to get the latest state of the task instance.
150+
task_instance = self.get_task_instance() # type: ignore[call-arg]
151+
task_state = task_instance.state
152+
return task_state != TaskInstanceState.DEFERRED
129153

130154
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
131155
"""Get current job execution status and yields a TriggerEvent."""

providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
3434
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
3535
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
36+
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
3637
from airflow.triggers.base import BaseTrigger, TriggerEvent
3738
from airflow.utils.session import provide_session
3839
from airflow.utils.state import TaskInstanceState
@@ -141,16 +142,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
141142
)
142143
return task_instance
143144

145+
def get_task_state(self):
146+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
147+
148+
task_states_response = RuntimeTaskInstance.get_task_states(
149+
dag_id=self.task_instance.dag_id,
150+
task_ids=[self.task_instance.task_id],
151+
run_ids=[self.task_instance.run_id],
152+
)
153+
try:
154+
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
155+
except Exception:
156+
raise AirflowException(
157+
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s is not found",
158+
self.task_instance.dag_id,
159+
self.task_instance.task_id,
160+
self.task_instance.run_id,
161+
)
162+
return task_state
163+
144164
def safe_to_cancel(self) -> bool:
145165
"""
146166
Whether it is safe to cancel the external job which is being executed by this trigger.
147167
148168
This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
149169
Because in those cases, we should NOT cancel the external job.
150170
"""
151-
# Database query is needed to get the latest state of the task instance.
152-
task_instance = self.get_task_instance() # type: ignore[call-arg]
153-
return task_instance.state != TaskInstanceState.DEFERRED
171+
if AIRFLOW_V_3_0_PLUS:
172+
task_state = self.get_task_state()
173+
else:
174+
# Database query is needed to get the latest state of the task instance.
175+
task_instance = self.get_task_instance() # type: ignore[call-arg]
176+
task_state = task_instance.state
177+
return task_state != TaskInstanceState.DEFERRED
154178

155179
async def run(self):
156180
try:
@@ -243,16 +267,39 @@ def get_task_instance(self, session: Session) -> TaskInstance:
243267
)
244268
return task_instance
245269

270+
def get_task_state(self):
271+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
272+
273+
task_states_response = RuntimeTaskInstance.get_task_states(
274+
dag_id=self.task_instance.dag_id,
275+
task_ids=[self.task_instance.task_id],
276+
run_ids=[self.task_instance.run_id],
277+
)
278+
try:
279+
task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
280+
except Exception:
281+
raise AirflowException(
282+
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s is not found",
283+
self.task_instance.dag_id,
284+
self.task_instance.task_id,
285+
self.task_instance.run_id,
286+
)
287+
return task_state
288+
246289
def safe_to_cancel(self) -> bool:
247290
"""
248291
Whether it is safe to cancel the external job which is being executed by this trigger.
249292
250293
This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
251294
Because in those cases, we should NOT cancel the external job.
252295
"""
253-
# Database query is needed to get the latest state of the task instance.
254-
task_instance = self.get_task_instance() # type: ignore[call-arg]
255-
return task_instance.state != TaskInstanceState.DEFERRED
296+
if AIRFLOW_V_3_0_PLUS:
297+
task_state = self.get_task_state()
298+
else:
299+
# Database query is needed to get the latest state of the task instance.
300+
task_instance = self.get_task_instance() # type: ignore[call-arg]
301+
task_state = task_instance.state
302+
return task_state != TaskInstanceState.DEFERRED
256303

257304
async def run(self) -> AsyncIterator[TriggerEvent]:
258305
try:

0 commit comments

Comments
 (0)