Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"
XCOM_DAG_ID = "trigger_dag_id"


if TYPE_CHECKING:
Expand Down Expand Up @@ -86,21 +87,28 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
if TYPE_CHECKING:
assert isinstance(operator, TriggerDagRunOperator)

trigger_dag_id = operator.trigger_dag_id
if not AIRFLOW_V_3_0_PLUS:
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey

core_ti_key = CoreTaskInstanceKey(
dag_id=ti_key.dag_id,
task_id=ti_key.task_id,
run_id=ti_key.run_id,
try_number=ti_key.try_number,
map_index=ti_key.map_index,
)
# Try to get the resolved dag_id from XCom first (for dynamic dag_ids).
trigger_dag_id = XCom.get_value(ti_key=ti_key, key=XCOM_DAG_ID)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: if not trigger_dag_id: uses a truthiness check. If XCom.get_value returns an empty string "" (unlikely but possible), this would incorrectly fall through to the operator attribute fallback.

Should be:
if trigger_dag_id is None:

# Fallback to operator attribute and rendered fields if not in XCom.
# Use identity check (`is None`) instead of truthiness so that any explicit
# value pushed to XCom (including unusual strings) is honored.
if trigger_dag_id is None:
trigger_dag_id = operator.trigger_dag_id
if not AIRFLOW_V_3_0_PLUS:
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey

core_ti_key = CoreTaskInstanceKey(
dag_id=ti_key.dag_id,
task_id=ti_key.task_id,
run_id=ti_key.run_id,
try_number=ti_key.try_number,
map_index=ti_key.map_index,
)

if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key):
trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef]
if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key):
trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef]

# Fetch the correct dag_run_id for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
Expand Down Expand Up @@ -356,10 +364,11 @@ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date):
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
# Store the run id from the dag run (either created or found above) to
# be used when creating the extra link on the webserver.
# Store the resolved dag_id and run id from the dag run (either created or
# found above) to be used when creating the extra link on the webserver.
ti = context["task_instance"]
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)
ti.xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)

if self.wait_for_completion:
# Kick off the deferral process
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,10 @@ def test_trigger_dagrun(self):
assert task.trigger_run_id == expected_run_id # run_id is saved as attribute

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
@mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_one")
def test_extra_operator_link(self, mock_xcom_get_one, dag_maker):
@mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_value")
def test_extra_operator_link(self, mock_xcom_get_value, dag_maker):
from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
Expand All @@ -191,14 +193,53 @@ def test_extra_operator_link(self, mock_xcom_get_one, dag_maker):
dr = dag_maker.create_dagrun(run_id="test_run_id")
ti = dr.get_task_instance(task_id=task.task_id)

mock_xcom_get_one.return_value = ti.run_id
# Return None for XCOM_DAG_ID (so get_link falls back to operator attribute)
# and the run_id for XCOM_RUN_ID.
def _get_value(ti_key, key):
if key == XCOM_RUN_ID:
return ti.run_id
return None

mock_xcom_get_value.side_effect = _get_value

link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key)

base_url = conf.get("api", "base_url", fallback="/").lower()
expected_url = f"{base_url}dags/{TRIGGERED_DAG_ID}/runs/test_run_id"
assert link == expected_url, f"Expected {expected_url}, but got {link}"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
@mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_value")
def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker):
"""Operator link uses the resolved ``dag_id`` from XCom when available."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
# In a real scenario this would be a template like "{{ ti.xcom_pull(...) }}".
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="test_run_id",
)

dr = dag_maker.create_dagrun(run_id="test_run_id")
ti = dr.get_task_instance(task_id=task.task_id)

def _get_value(ti_key, key):
if key == XCOM_DAG_ID:
return "dynamic_dag_id"
if key == XCOM_RUN_ID:
return "dynamic_run_id"
return None

mock_xcom_get_value.side_effect = _get_value

link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key)

base_url = conf.get("api", "base_url", fallback="/").lower()
expected_url = f"{base_url}dags/dynamic_dag_id/runs/dynamic_run_id"
assert link == expected_url, f"Expected {expected_url}, but got {link}"

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dagrun_custom_run_id(self):
task = TriggerDagRunOperator(
Expand Down Expand Up @@ -666,6 +707,47 @@ def test_extra_operator_link(self, dag_maker, session):
}
assert expected_args in args

@mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_value")
def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker, session):
"""Operator link uses the resolved ``dag_id`` from XCom when available (AF2)."""
from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID

with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="test_run_id",
)
dag_maker.create_dagrun()
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

triggering_ti = session.scalar(
select(TaskInstance).where(
TaskInstance.task_id == task.task_id, TaskInstance.dag_id == task.dag_id
)
)
assert triggering_ti is not None

def _get_value(ti_key, key):
if key == XCOM_DAG_ID:
return "dynamic_dag_id"
if key == XCOM_RUN_ID:
return "test_run_id"
return None

mock_xcom_get_value.side_effect = _get_value

with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url:
task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key)
assert mock_build_url.called
args, _ = mock_build_url.call_args
# Should use the dag_id resolved from XCom, not the operator attribute.
expected_args = {
"dag_id": "dynamic_dag_id",
"dag_run_id": "test_run_id",
}
assert expected_args in args

def test_trigger_dagrun_with_logical_date(self, dag_maker):
"""Test TriggerDagRunOperator with custom logical_date."""
custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5)
Expand Down
7 changes: 5 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,8 +1578,11 @@ def _handle_trigger_dag_run(

log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

# Store the run id from the dag run (either created or found above) to
# be used when creating the extra link on the webserver.
# Store the resolved dag_id and run id from the dag run (either created or found
# above) to be used when creating the extra link on the webserver. Pushed here
# (after a successful trigger) so the XCom state stays consistent with
# ``trigger_run_id`` and is not written when the trigger fails.
ti.xcom_push(key="trigger_dag_id", value=drte.trigger_dag_id)
ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)

if drte.wait_for_completion:
Expand Down
24 changes: 23 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4538,6 +4538,16 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
),
),
mock.call.send(
msg=SetXCom(
key="trigger_dag_id",
value="test_dag",
dag_id="test_handle_trigger_dag_run",
task_id="test_task",
run_id="test_run",
map_index=-1,
),
),
mock.call.send(
msg=SetXCom(
key="trigger_run_id",
Expand Down Expand Up @@ -4638,7 +4648,9 @@ def test_handle_trigger_dag_run_wait_for_completion(
None,
# Successful Dag Run trigger
OKResponse(ok=True),
# Set XCOM,
# Set XCOM trigger_dag_id
None,
# Set XCOM trigger_run_id
None,
# Dag Run is still running
DagRunStateResult(state=DagRunState.RUNNING),
Expand All @@ -4661,6 +4673,16 @@ def test_handle_trigger_dag_run_wait_for_completion(
logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
),
),
mock.call.send(
msg=SetXCom(
key="trigger_dag_id",
value="test_dag",
dag_id="test_handle_trigger_dag_run_wait_for_completion",
task_id="test_task",
run_id="test_run",
map_index=-1,
),
),
mock.call.send(
msg=SetXCom(
key="trigger_run_id",
Expand Down
Loading