Skip to content

Commit 6dc949f

Browse files
amoghrajesh authored and kaxil committed
Treat single task_ids in xcom_pull the same as multiple (apache#49692)
* Treat single task_ids in xcom_pull the same as multiple

  closes: apache#49540

* fixup! Treat single task_ids in xcom_pull the same as multiple

---------

Co-authored-by: Kaxil Naik <[email protected]>
1 parent c3d3b4c commit 6dc949f

File tree

4 files changed

+56
-34
lines changed

4 files changed

+56
-34
lines changed

RELEASE_NOTES.rst

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,6 @@ aligning with the broader asset-aware execution model introduced in Airflow 3.0.
336336
Behaviour change in ``xcom_pull``
337337
"""""""""""""""""""""""""""""""""
338338

339-
**Pulling without setting ``task_ids``**:
340-
341339
In Airflow 2, the ``xcom_pull()`` method allowed pulling XComs by key without specifying task_ids, despite the fact that the underlying
342340
DB model defines task_id as part of the XCom primary key. This created ambiguity: if two tasks pushed XComs with the same key,
343341
``xcom_pull()`` would pull whichever one happened to be first, leading to unpredictable behavior.
@@ -354,34 +352,6 @@ Should be updated to::
354352
kwargs["ti"].xcom_pull(task_ids="task1", key="key")
355353

356354

357-
**Return Type Change for Single Task ID**:
358-
359-
In Airflow 2, when using ``xcom_pull()`` with a single task ID in a list (e.g., ``task_ids=["task1"]``), it would return a ``LazyXComSelectSequence``
360-
object containing one value. In Airflow 3.0.0, this behavior was changed to return the value directly.
361-
362-
So, if you previously used:
363-
364-
.. code-block:: python
365-
366-
xcom_values = kwargs["ti"].xcom_pull(task_ids=["task1"], key="key")
367-
xcom_value = xcom_values[0] # Access the first value
368-
369-
You would now get the value directly, rather than a sequence containing one value.
370-
371-
.. code-block:: python
372-
373-
xcom_value = kwargs["ti"].xcom_pull(task_ids=["task1"], key="key")
374-
375-
The previous behaviour (returning list when passed a list) will be restored in Airflow 3.0.1 to maintain backward compatibility.
376-
377-
However, it is recommended to be explicit about your intentions when using ``task_ids`` (after the fix in 3.0.1):
378-
379-
- If you want a single value, use ``task_ids="task1"``
380-
- If you want a sequence, use ``task_ids=["task1"]``
381-
382-
This makes the code more explicit and easier to understand.
383-
384-
385355
Removed Configuration Keys
386356
"""""""""""""""""""""""""""
387357

reproducible_build.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
release-notes-hash: 77a6fba681cf21973ca9712136d1b51a
2-
source-date-epoch: 1745327923
1+
release-notes-hash: df3b67b987fd909d16f6158df78b3813
2+
source-date-epoch: 1745660315

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ def xcom_pull(
325325
if run_id is None:
326326
run_id = self.run_id
327327

328+
single_task_requested = isinstance(task_ids, (str, type(None)))
329+
single_map_index_requested = isinstance(map_indexes, (int, type(None), ArgNotSet))
330+
328331
if task_ids is None:
329332
# default to the current task if not provided
330333
task_ids = [self.task_id]
@@ -363,8 +366,9 @@ def xcom_pull(
363366
else:
364367
xcoms.append(value)
365368

366-
if len(xcoms) == 1:
369+
if single_task_requested and single_map_index_requested:
367370
return xcoms[0]
371+
368372
return xcoms
369373

370374
def xcom_push(self, key: str, value: Any):

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,7 @@ def test_xcom_pull(
12671267
):
12681268
"""
12691269
Test that a task makes an expected call to the Supervisor to pull XCom values
1270-
based on various task_ids and map_indexes configurations.
1270+
based on various task_ids, map_indexes, and xcom_values configurations.
12711271
"""
12721272
map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes": map_indexes}
12731273
task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids": task_ids}
@@ -1313,6 +1313,54 @@ def execute(self, context):
13131313
),
13141314
)
13151315

1316+
@pytest.mark.parametrize(
1317+
"task_ids, map_indexes, expected_value",
1318+
[
1319+
pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"),
1320+
pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"),
1321+
pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"),
1322+
pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"),
1323+
pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"),
1324+
pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"),
1325+
pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"),
1326+
pytest.param(
1327+
["task_a"], NOTSET, [{"a": 1, "b": 2}], id="task_id is list, map_index is ArgNotSet"
1328+
),
1329+
pytest.param(None, 0, {"a": 1, "b": 2}, id="task_id is None, map_index is int"),
1330+
pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"),
1331+
pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"),
1332+
pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"),
1333+
],
1334+
)
1335+
def test_xcom_pull_return_values(
1336+
self,
1337+
create_runtime_ti,
1338+
mock_supervisor_comms,
1339+
task_ids,
1340+
map_indexes,
1341+
expected_value,
1342+
):
1343+
"""
1344+
Tests return value of xcom_pull under various combinations of task_ids and map_indexes.
1345+
The above test covers the expected calls to supervisor comms.
1346+
"""
1347+
1348+
class CustomOperator(BaseOperator):
1349+
def execute(self, context):
1350+
print("This is a custom operator")
1351+
1352+
test_task_id = "pull_task"
1353+
task = CustomOperator(task_id=test_task_id)
1354+
runtime_ti = create_runtime_ti(task=task)
1355+
1356+
value = {"a": 1, "b": 2}
1357+
# API server returns serialised value for xcom result, staging it in that way
1358+
xcom_value = BaseXCom.serialize_value(value)
1359+
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value)
1360+
1361+
returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
1362+
assert returned_xcom == expected_value
1363+
13161364
def test_get_param_from_context(
13171365
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
13181366
):

0 commit comments

Comments
 (0)