@@ -1267,7 +1267,7 @@ def test_xcom_pull(
1267
1267
):
1268
1268
"""
1269
1269
Test that a task makes an expected call to the Supervisor to pull XCom values
1270
- based on various task_ids and map_indexes configurations.
1270
+ based on various task_ids, map_indexes, and xcom_values configurations.
1271
1271
"""
1272
1272
map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes" : map_indexes }
1273
1273
task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids" : task_ids }
@@ -1313,6 +1313,54 @@ def execute(self, context):
1313
1313
),
1314
1314
)
1315
1315
1316
+ @pytest .mark .parametrize (
1317
+ "task_ids, map_indexes, expected_value" ,
1318
+ [
1319
+ pytest .param ("task_a" , 0 , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is int" ),
1320
+ pytest .param ("task_a" , [0 ], [{"a" : 1 , "b" : 2 }], id = "task_id is str, map_index is list" ),
1321
+ pytest .param ("task_a" , None , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is None" ),
1322
+ pytest .param ("task_a" , NOTSET , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is ArgNotSet" ),
1323
+ pytest .param (["task_a" ], 0 , [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is int" ),
1324
+ pytest .param (["task_a" ], [0 ], [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is list" ),
1325
+ pytest .param (["task_a" ], None , [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is None" ),
1326
+ pytest .param (
1327
+ ["task_a" ], NOTSET , [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is ArgNotSet"
1328
+ ),
1329
+ pytest .param (None , 0 , {"a" : 1 , "b" : 2 }, id = "task_id is None, map_index is int" ),
1330
+ pytest .param (None , [0 ], [{"a" : 1 , "b" : 2 }], id = "task_id is None, map_index is list" ),
1331
+ pytest .param (None , None , {"a" : 1 , "b" : 2 }, id = "task_id is None, map_index is None" ),
1332
+ pytest .param (None , NOTSET , {"a" : 1 , "b" : 2 }, id = "task_id is None, map_index is ArgNotSet" ),
1333
+ ],
1334
+ )
1335
+ def test_xcom_pull_return_values (
1336
+ self ,
1337
+ create_runtime_ti ,
1338
+ mock_supervisor_comms ,
1339
+ task_ids ,
1340
+ map_indexes ,
1341
+ expected_value ,
1342
+ ):
1343
+ """
1344
+ Tests return value of xcom_pull under various combinations of task_ids and map_indexes.
1345
+ The above test covers the expected calls to supervisor comms.
1346
+ """
1347
+
1348
+ class CustomOperator (BaseOperator ):
1349
+ def execute (self , context ):
1350
+ print ("This is a custom operator" )
1351
+
1352
+ test_task_id = "pull_task"
1353
+ task = CustomOperator (task_id = test_task_id )
1354
+ runtime_ti = create_runtime_ti (task = task )
1355
+
1356
+ value = {"a" : 1 , "b" : 2 }
1357
+ # API server returns serialised value for xcom result, staging it in that way
1358
+ xcom_value = BaseXCom .serialize_value (value )
1359
+ mock_supervisor_comms .get_message .return_value = XComResult (key = "key" , value = xcom_value )
1360
+
1361
+ returned_xcom = runtime_ti .xcom_pull (key = "key" , task_ids = task_ids , map_indexes = map_indexes )
1362
+ assert returned_xcom == expected_value
1363
+
1316
1364
def test_get_param_from_context (
1317
1365
self , mocked_parse , make_ti_context , mock_supervisor_comms , create_runtime_ti
1318
1366
):
0 commit comments