Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions airflow-core/tests/unit/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,39 @@ def check_task_2(my_input):
mock_task_object_1.assert_called()
mock_task_object_2.assert_not_called()

def test_dag_test_failure_callback_receives_exception_in_context(self, testing_dag_bundle):
captured: dict = {}

def handle_task_failure(context):
captured["exception"] = context.get("exception")
captured["task_id"] = context["task_instance"].task_id

dag = DAG(
dag_id="test_dag_test_failure_ctx_exception",
default_args={"on_failure_callback": handle_task_failure},
start_date=DEFAULT_DATE,
schedule=None,
)
sync_dag_to_db(dag)

@task_decorator
def boom():
raise AirflowException("boooom")

with dag:
boom()
sync_dag_to_db(dag)

dr = dag.test()
ti = dr.get_task_instance("boom")
assert ti is not None
assert ti.state == TaskInstanceState.FAILED
# The in-process supervisor must mirror task_runner.main(): expose the raised
# exception in context["exception"] so on_failure_callback can use it.
assert captured.get("task_id") == "boom"
assert isinstance(captured.get("exception"), AirflowException)
assert str(captured["exception"]) == "boooom"

def test_dag_connection_file(self, tmp_path, testing_dag_bundle):
test_connections_string = """
---
Expand Down
3 changes: 3 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,9 @@ def start( # type: ignore[override]
log = structlog.get_logger(logger_name="task")

state, msg, error = run(ti, context, log)
# Mirror the subprocess path in `task_runner.main()`: expose the raised
# exception via `context["exception"]` so failure/retry callbacks see it.
context["exception"] = error
finalize(ti, state, context, log, error)

# In the normal subprocess model, the task runner calls this before exiting.
Expand Down
72 changes: 72 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2901,6 +2901,78 @@ def _handle_request(self, msg, log, req_id):
assert isinstance(response, VariableResult)
assert response.value == "value"

@pytest.mark.parametrize(
"task_state",
[TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY],
)
def test_start_sets_exception_in_context_before_finalize(self, mocker, make_ti_context, task_state):
"""
Regression test for the in-process dag.test() path.

`task_runner.main()` (the subprocess path) sets ``context["exception"] = error``
immediately after ``run()`` and before ``finalize()`` so that
``on_failure_callback`` and ``on_retry_callback`` can read the raised
exception from context. The in-process supervisor used by ``dag.test()``
must mirror that contract.
"""
raised = ValueError("kaboom")
captured: dict[str, Any] = {}

def fake_finalize(_ti, state, context, _log, error=None):
captured["state"] = state
captured["error"] = error
captured["exception_in_context"] = context.get("exception")

mocker.patch(
"airflow.sdk.execution_time.task_runner.run",
return_value=(task_state, MagicMock(), raised),
)
mocker.patch(
"airflow.sdk.execution_time.task_runner.finalize",
side_effect=fake_finalize,
)

fake_runtime_ti = MagicMock()
fake_runtime_ti.get_template_context.return_value = {}
mocker.patch(
"airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.model_construct",
return_value=fake_runtime_ti,
)

# No real sockets / no real terminal-state notification needed for this test.
mocker.patch.object(
InProcessTestSupervisor,
"_setup_subprocess_socket",
return_value=mock.MagicMock(),
)
mocker.patch.object(InProcessTestSupervisor, "update_task_state_if_needed")

mock_client = MagicMock(spec=sdk_client.Client)
mock_client.task_instances.start.return_value = make_ti_context()
mocker.patch.object(InProcessTestSupervisor, "_api_client", return_value=mock_client)

ti = TaskInstance(
id=uuid7(),
task_id="t",
dag_id="d",
run_id="r",
try_number=1,
dag_version_id=uuid7(),
)
task = MagicMock()
task.dag = MagicMock()

result = InProcessTestSupervisor.start(what=ti, task=task)

assert result.state == task_state
assert result.error is raised
assert captured["state"] == task_state
assert captured["error"] is raised
# The bug being fixed: without the in-process supervisor mirroring the
# subprocess path, this would be missing/None and failure or retry
# callbacks couldn't read the exception from context.
assert captured["exception_in_context"] is raised


class TestInProcessClient:
def test_no_retries(self):
Expand Down
Loading