diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 73884433af5aa..0eaf5693470bf 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -54,6 +54,7 @@ XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso" XCOM_RUN_ID = "trigger_run_id" +XCOM_DAG_ID = "trigger_dag_id" if TYPE_CHECKING: @@ -86,21 +87,28 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: if TYPE_CHECKING: assert isinstance(operator, TriggerDagRunOperator) - trigger_dag_id = operator.trigger_dag_id - if not AIRFLOW_V_3_0_PLUS: - from airflow.models.renderedtifields import RenderedTaskInstanceFields - from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey - - core_ti_key = CoreTaskInstanceKey( - dag_id=ti_key.dag_id, - task_id=ti_key.task_id, - run_id=ti_key.run_id, - try_number=ti_key.try_number, - map_index=ti_key.map_index, - ) + # Try to get the resolved dag_id from XCom first (for dynamic dag_ids). + trigger_dag_id = XCom.get_value(ti_key=ti_key, key=XCOM_DAG_ID) + + # Fallback to operator attribute and rendered fields if not in XCom. + # Use identity check (`is None`) instead of truthiness so that any explicit + # value pushed to XCom (including unusual strings) is honored. + if trigger_dag_id is None: + trigger_dag_id = operator.trigger_dag_id + if not AIRFLOW_V_3_0_PLUS: + from airflow.models.renderedtifields import RenderedTaskInstanceFields + from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey + + core_ti_key = CoreTaskInstanceKey( + dag_id=ti_key.dag_id, + task_id=ti_key.task_id, + run_id=ti_key.run_id, + try_number=ti_key.try_number, + map_index=ti_key.map_index, + ) - if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key): - trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef] + if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key): + trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef] # Fetch the correct dag_run_id for the triggerED dag which is # stored in xcom during execution of the triggerING task. @@ -356,10 +364,11 @@ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date): raise e if dag_run is None: raise RuntimeError("The dag_run should be set here!") - # Store the run id from the dag run (either created or found above) to - # be used when creating the extra link on the webserver. + # Store the resolved dag_id and run id from the dag run (either created or + # found above) to be used when creating the extra link on the webserver. ti = context["task_instance"] ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) + ti.xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id) if self.wait_for_completion: # Kick off the deferral process diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py index bcc5ad6153c6e..71dc1a5d57c40 100644 --- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py +++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py @@ -178,8 +178,10 @@ def test_trigger_dagrun(self): assert task.trigger_run_id == expected_run_id # run_id is saved as attribute @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") - @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_one") - def test_extra_operator_link(self, mock_xcom_get_one, dag_maker): + @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_value") + def test_extra_operator_link(self, mock_xcom_get_value, dag_maker): + from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): task = TriggerDagRunOperator( task_id="test_task", @@ -191,7 +193,14 @@ def test_extra_operator_link(self, mock_xcom_get_one, dag_maker): dr = dag_maker.create_dagrun(run_id="test_run_id") ti = dr.get_task_instance(task_id=task.task_id) - mock_xcom_get_one.return_value = ti.run_id + # Return None for XCOM_DAG_ID (so get_link falls back to operator attribute) + # and the run_id for XCOM_RUN_ID. + def _get_value(ti_key, key): + if key == XCOM_RUN_ID: + return ti.run_id + return None + + mock_xcom_get_value.side_effect = _get_value link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) @@ -199,6 +208,38 @@ def test_extra_operator_link(self, mock_xcom_get_one, dag_maker): expected_url = f"{base_url}dags/{TRIGGERED_DAG_ID}/runs/test_run_id" assert link == expected_url, f"Expected {expected_url}, but got {link}" + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_value") + def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker): + """Operator link uses the resolved ``dag_id`` from XCom when available.""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + # In a real scenario this would be a template like "{{ ti.xcom_pull(...) }}". + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="test_run_id", + ) + + dr = dag_maker.create_dagrun(run_id="test_run_id") + ti = dr.get_task_instance(task_id=task.task_id) + + def _get_value(ti_key, key): + if key == XCOM_DAG_ID: + return "dynamic_dag_id" + if key == XCOM_RUN_ID: + return "dynamic_run_id" + return None + + mock_xcom_get_value.side_effect = _get_value + + link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) + + base_url = conf.get("api", "base_url", fallback="/").lower() + expected_url = f"{base_url}dags/dynamic_dag_id/runs/dynamic_run_id" + assert link == expected_url, f"Expected {expected_url}, but got {link}" + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( @@ -666,6 +707,47 @@ def test_extra_operator_link(self, dag_maker, session): } assert expected_args in args + @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_value") + def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker, session): + """Operator link uses the resolved ``dag_id`` from XCom when available (AF2).""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="test_run_id", + ) + dag_maker.create_dagrun() + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + triggering_ti = session.scalar( + select(TaskInstance).where( + TaskInstance.task_id == task.task_id, TaskInstance.dag_id == task.dag_id + ) + ) + assert triggering_ti is not None + + def _get_value(ti_key, key): + if key == XCOM_DAG_ID: + return "dynamic_dag_id" + if key == XCOM_RUN_ID: + return "test_run_id" + return None + + mock_xcom_get_value.side_effect = _get_value + + with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url: + task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key) + assert mock_build_url.called + args, _ = mock_build_url.call_args + # Should use the dag_id resolved from XCom, not the operator attribute. + expected_args = { + "dag_id": "dynamic_dag_id", + "dag_run_id": "test_run_id", + } + assert expected_args in args + def test_trigger_dagrun_with_logical_date(self, dag_maker): """Test TriggerDagRunOperator with custom logical_date.""" custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 7c318fc499ed6..3b5bc3934bbdc 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1578,8 +1578,11 @@ def _handle_trigger_dag_run( log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id) - # Store the run id from the dag run (either created or found above) to - # be used when creating the extra link on the webserver. + # Store the resolved dag_id and run id from the dag run (either created or found + # above) to be used when creating the extra link on the webserver. Pushed here + # (after a successful trigger) so the XCom state stays consistent with + # ``trigger_run_id`` and is not written when the trigger fails. + ti.xcom_push(key="trigger_dag_id", value=drte.trigger_dag_id) ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id) if drte.wait_for_completion: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 723ca42d93aa6..7959da8a9b14d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -4538,6 +4538,16 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), ), + mock.call.send( + msg=SetXCom( + key="trigger_dag_id", + value="test_dag", + dag_id="test_handle_trigger_dag_run", + task_id="test_task", + run_id="test_run", + map_index=-1, + ), + ), mock.call.send( msg=SetXCom( key="trigger_run_id", @@ -4638,7 +4648,9 @@ def test_handle_trigger_dag_run_wait_for_completion( None, # Successful Dag Run trigger OKResponse(ok=True), - # Set XCOM, + # Set XCOM trigger_dag_id + None, + # Set XCOM trigger_run_id None, # Dag Run is still running DagRunStateResult(state=DagRunState.RUNNING), @@ -4661,6 +4673,16 @@ def test_handle_trigger_dag_run_wait_for_completion( logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), ), + mock.call.send( + msg=SetXCom( + key="trigger_dag_id", + value="test_dag", + dag_id="test_handle_trigger_dag_run_wait_for_completion", + task_id="test_task", + run_id="test_run", + map_index=-1, + ), + ), mock.call.send( msg=SetXCom( key="trigger_run_id",