diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 4b27314e0eb3d..84b516ad13b39 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1684,6 +1684,7 @@ def start( # type: ignore[override] log = structlog.get_logger(logger_name="task") state, msg, error = run(ti, context, log) + context["exception"] = error finalize(ti, state, context, log, error) # In the normal subprocess model, the task runner calls this before exiting. diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 3ab95e0810e09..da6d41c9f99f4 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -49,7 +49,7 @@ from uuid6 import uuid7 from airflow.executors.workloads import BundleInfo -from airflow.sdk import BaseOperator, timezone +from airflow.sdk import DAG, BaseOperator, timezone from airflow.sdk.api import client as sdk_client from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import ( @@ -2807,6 +2807,56 @@ def _handle_request(self, msg, log, req_id): assert isinstance(response, VariableResult) assert response.value == "value" + def test_inprocess_failure_callback_receives_exception( + self, + monkeypatch, + make_ti_context, + ): + """Run a failing task via InProcessTestSupervisor and ensure the + `on_failure_callback` receives `context['exception']`. + """ + collected: list[BaseException | None] = [None] + + class _Failure(Exception): + pass + + def failure_callback(context): + collected[0] = context.get("exception") + + class FailingOperator(BaseOperator): + def execute(self, context=None): + raise _Failure("boom") + + task = FailingOperator(task_id="failing", on_failure_callback=failure_callback) + + # Assign a minimal DAG to the operator so `task.dag` access succeeds + task.dag = DAG(dag_id="test_dag") + + # Create a simple TaskInstance datamodel to pass to the supervisor + ti = TaskInstance( + id=uuid7(), + task_id=task.task_id, + dag_id="test_dag", + run_id="r", + try_number=1, + dag_version_id=uuid7(), + ) + + # Patch the API client used by InProcessTestSupervisor to return a predictable TI context + fake_task_instances = mock.MagicMock(spec_set=["start", "finish"]) + fake_task_instances.start.return_value = make_ti_context() + fake_client = mock.MagicMock(spec_set=["task_instances"]) + fake_client.task_instances = fake_task_instances + monkeypatch.setattr( + InProcessTestSupervisor, "_api_client", staticmethod(lambda dag=None: fake_client) + ) + + result = InProcessTestSupervisor.start(what=ti, task=task) + + # Ensure the task failed and the callback saw the exception + assert isinstance(result.error, _Failure) + assert isinstance(collected[0], _Failure) + class TestInProcessClient: def test_no_retries(self):