diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index b34ab12dd4aef..1d02acdebcfdb 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -1718,6 +1718,39 @@ def check_task_2(my_input): mock_task_object_1.assert_called() mock_task_object_2.assert_not_called() + def test_dag_test_failure_callback_receives_exception_in_context(self, testing_dag_bundle): + captured: dict = {} + + def handle_task_failure(context): + captured["exception"] = context.get("exception") + captured["task_id"] = context["task_instance"].task_id + + dag = DAG( + dag_id="test_dag_test_failure_ctx_exception", + default_args={"on_failure_callback": handle_task_failure}, + start_date=DEFAULT_DATE, + schedule=None, + ) + sync_dag_to_db(dag) + + @task_decorator + def boom(): + raise AirflowException("boooom") + + with dag: + boom() + sync_dag_to_db(dag) + + dr = dag.test() + ti = dr.get_task_instance("boom") + assert ti is not None + assert ti.state == TaskInstanceState.FAILED + # The in-process supervisor must mirror task_runner.main(): expose the raised + # exception in context["exception"] so on_failure_callback can use it. + assert captured.get("task_id") == "boom" + assert isinstance(captured.get("exception"), AirflowException) + assert str(captured["exception"]) == "boooom" + def test_dag_connection_file(self, tmp_path, testing_dag_bundle): test_connections_string = """ --- diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 375c5a9e30b8e..7fb246e0fec23 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1804,6 +1804,9 @@ def start( # type: ignore[override] log = structlog.get_logger(logger_name="task") state, msg, error = run(ti, context, log) + # Mirror the subprocess path in `task_runner.main()`: expose the raised + # exception via `context["exception"]` so failure/retry callbacks see it. + context["exception"] = error finalize(ti, state, context, log, error) # In the normal subprocess model, the task runner calls this before exiting. diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 3695af1fff592..5a310961f32ab 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2901,6 +2901,78 @@ def _handle_request(self, msg, log, req_id): assert isinstance(response, VariableResult) assert response.value == "value" + @pytest.mark.parametrize( + "task_state", + [TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY], + ) + def test_start_sets_exception_in_context_before_finalize(self, mocker, make_ti_context, task_state): + """ + Regression test for the in-process dag.test() path. + + `task_runner.main()` (the subprocess path) sets ``context["exception"] = error`` + immediately after ``run()`` and before ``finalize()`` so that + ``on_failure_callback`` and ``on_retry_callback`` can read the raised + exception from context. The in-process supervisor used by ``dag.test()`` + must mirror that contract. + """ + raised = ValueError("kaboom") + captured: dict[str, Any] = {} + + def fake_finalize(_ti, state, context, _log, error=None): + captured["state"] = state + captured["error"] = error + captured["exception_in_context"] = context.get("exception") + + mocker.patch( + "airflow.sdk.execution_time.task_runner.run", + return_value=(task_state, MagicMock(), raised), + ) + mocker.patch( + "airflow.sdk.execution_time.task_runner.finalize", + side_effect=fake_finalize, + ) + + fake_runtime_ti = MagicMock() + fake_runtime_ti.get_template_context.return_value = {} + mocker.patch( + "airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.model_construct", + return_value=fake_runtime_ti, + ) + + # No real sockets / no real terminal-state notification needed for this test. + mocker.patch.object( + InProcessTestSupervisor, + "_setup_subprocess_socket", + return_value=mock.MagicMock(), + ) + mocker.patch.object(InProcessTestSupervisor, "update_task_state_if_needed") + + mock_client = MagicMock(spec=sdk_client.Client) + mock_client.task_instances.start.return_value = make_ti_context() + mocker.patch.object(InProcessTestSupervisor, "_api_client", return_value=mock_client) + + ti = TaskInstance( + id=uuid7(), + task_id="t", + dag_id="d", + run_id="r", + try_number=1, + dag_version_id=uuid7(), + ) + task = MagicMock() + task.dag = MagicMock() + + result = InProcessTestSupervisor.start(what=ti, task=task) + + assert result.state == task_state + assert result.error is raised + assert captured["state"] == task_state + assert captured["error"] is raised + # The bug being fixed: without the in-process supervisor mirroring the + # subprocess path, this would be missing/None and failure or retry + # callbacks couldn't read the exception from context. + assert captured["exception_in_context"] is raised + class TestInProcessClient: def test_no_retries(self):