diff --git a/airflow-core/newsfragments/66410.feature.rst b/airflow-core/newsfragments/66410.feature.rst new file mode 100644 index 0000000000000..9faebc476ff62 --- /dev/null +++ b/airflow-core/newsfragments/66410.feature.rst @@ -0,0 +1 @@ +Add the ``on_task_instance_checkpointed`` listener hook so listener authors observe checkpoint events alongside running, success, failed, and skipped. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 2983263bbc59b..bdf69e69be527 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -3546,6 +3546,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 5ed96855c24ac..64a120d893c37 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -14445,6 +14445,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/airflow-core/src/airflow/example_dags/plugins/event_listener.py b/airflow-core/src/airflow/example_dags/plugins/event_listener.py index 91af9f5ccc6df..252ce2b65f344 100644 --- a/airflow-core/src/airflow/example_dags/plugins/event_listener.py +++ b/airflow-core/src/airflow/example_dags/plugins/event_listener.py @@ -173,6 +173,26 @@ def on_task_instance_skipped( # [END howto_listen_ti_skipped_task] +# [START howto_listen_ti_checkpointed_task] +@hookimpl +def on_task_instance_checkpointed( + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, + checkpoint_data, +): + """ + Called when a task reaches a stable checkpoint and pauses. + + The operator raised ``AirflowTaskCheckpointed`` from ``execute()`` and the + worker reported ``CHECKPOINTED`` state. ``checkpoint_data`` is the payload + the operator passed when raising; it can be ``None``. + """ + print(f"Task instance reached a checkpoint (data={checkpoint_data!r})") + + +# [END howto_listen_ti_checkpointed_task] + + # [START howto_listen_dagrun_success_task] @hookimpl def on_dag_run_success(dag_run: DagRun, msg: str): diff --git a/airflow-core/src/airflow/utils/state.py b/airflow-core/src/airflow/utils/state.py index 332efb105533d..890043f094656 100644 --- a/airflow-core/src/airflow/utils/state.py +++ b/airflow-core/src/airflow/utils/state.py @@ -56,6 +56,7 @@ class IntermediateTIState(str, Enum): UP_FOR_RETRY = "up_for_retry" UP_FOR_RESCHEDULE = "up_for_reschedule" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" def __str__(self) -> str: return self.value @@ -87,6 +88,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = TerminalTIState.UPSTREAM_FAILED # One or more upstream deps failed SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other mechanism DEFERRED = IntermediateTIState.DEFERRED # Deferrable operator waiting on a trigger + CHECKPOINTED = IntermediateTIState.CHECKPOINTED # Operator paused at a stable checkpoint def __str__(self) -> str: return self.value @@ -130,6 +132,7 @@ class State: UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED SKIPPED = TaskInstanceState.SKIPPED DEFERRED = TaskInstanceState.DEFERRED + CHECKPOINTED = TaskInstanceState.CHECKPOINTED finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED]) unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING]) @@ -157,6 +160,7 @@ class State: TaskInstanceState.REMOVED: "lightgrey", TaskInstanceState.SCHEDULED: "tan", TaskInstanceState.DEFERRED: "mediumpurple", + TaskInstanceState.CHECKPOINTED: "lightyellow", } @classmethod @@ -200,6 +204,7 @@ def color_fg(cls, state): TaskInstanceState.UP_FOR_RETRY, TaskInstanceState.UP_FOR_RESCHEDULE, TaskInstanceState.DEFERRED, + TaskInstanceState.CHECKPOINTED, ] ) """ diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index 276c8699de058..d549d78711481 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -841,6 +841,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = "upstream_failed" SKIPPED = "skipped" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class TaskInstancesBatchBody(BaseModel): diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 01c8149d1dad8..e894eda0837dd 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1274,6 +1274,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py index d3450d6b05aa7..f524b12226169 100644 --- a/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py +++ b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py @@ -79,3 +79,25 @@ def on_task_instance_skipped( :param task_instance: The task instance object (RuntimeTaskInstance when called from task execution context, TaskInstance when called from API server) """ + + +@hookspec +def on_task_instance_checkpointed( + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, + checkpoint_data, +): + """ + Execute when a task instance reaches a stable checkpoint and pauses. + + Fires when an operator raises ``AirflowTaskCheckpointed`` and the worker + reports ``CHECKPOINTED`` state. + + :param previous_state: Previous state of the task instance (typically RUNNING) + :param task_instance: The task instance object (RuntimeTaskInstance when + called from task execution context, TaskInstance when called from API + server) + :param checkpoint_data: The serializable payload the operator passed to + ``AirflowTaskCheckpointed(checkpoint_data=...)``, or ``None`` if no + payload was supplied. + """ diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b5b100154c389..306cd386c38ac 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -198,6 +198,7 @@ class IntermediateTIState(str, Enum): UP_FOR_RETRY = "up_for_retry" UP_FOR_RESCHEDULE = "up_for_reschedule" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class PrevSuccessfulDagRunResponse(BaseModel): @@ -365,6 +366,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = "upstream_failed" SKIPPED = "skipped" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class TaskStatePutBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 7d42dad5d8502..cbde73db11224 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -165,6 +165,33 @@ class AirflowSkipException(AirflowException): """Raise when the task should be skipped.""" +class AirflowTaskCheckpointed(AirflowException): + """ + Raise when the operator has reached a stable checkpoint and intends to pause. + + The worker reports the task as CHECKPOINTED and persists the optional + ``checkpoint_data`` so a subsequent run can resume from the same point. + + This is the AIP-96 (Resumable Operators) foundation primitive. Auto-resume + semantics, scheduler treatment, and downstream trigger-rule integration are + intentionally deferred to follow-ups so the API shape can be discussed + without committing to a single resumption policy. + + :param checkpoint_data: Arbitrary serializable payload representing the + operator's resume point. Persistence and resume wiring are out of scope + for this foundation; the parameter exists so listeners and operator + authors can iterate against the final shape. + """ + + def __init__(self, checkpoint_data=None): + super().__init__() + self.checkpoint_data = checkpoint_data + + def serialize(self): + cls = self.__class__ + return f"{cls.__module__}.{cls.__name__}", (), {"checkpoint_data": self.checkpoint_data} + + class AirflowTaskTerminated(BaseException): """Raise when the task execution is terminated.""" diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 87c7881333ad4..ded750db1840c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -754,6 +754,7 @@ class TaskState(BaseModel): TaskInstanceState.FAILED, TaskInstanceState.SKIPPED, TaskInstanceState.REMOVED, + TaskInstanceState.CHECKPOINTED, ] end_date: datetime | None = None type: Literal["TaskState"] = "TaskState" diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 56ba8343c648b..af286310964c3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1248,6 +1248,7 @@ def run( AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, + AirflowTaskCheckpointed, AirflowTaskTerminated, DagRunTriggerException, DownstreamTasksSkipped, @@ -1342,6 +1343,15 @@ def _on_term(signum, frame): rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.SKIPPED + except AirflowTaskCheckpointed as checkpoint: + log.info("Task checkpointed; reporting CHECKPOINTED state.") + msg = TaskState( + state=TaskInstanceState.CHECKPOINTED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) + state = TaskInstanceState.CHECKPOINTED + ti._checkpoint_data = checkpoint.checkpoint_data # type: ignore[attr-defined] except AirflowRescheduleException as reschedule: log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") msg = RescheduleTask( @@ -1932,6 +1942,15 @@ def finalize( log.exception("error calling listener") if error and task.email_on_failure and task.email: _send_error_email_notification(task, ti, context, error, log) + elif state == TaskInstanceState.CHECKPOINTED: + try: + get_listener_manager().hook.on_task_instance_checkpointed( + previous_state=TaskInstanceState.RUNNING, + task_instance=ti, + checkpoint_data=getattr(ti, "_checkpoint_data", None), + ) + except Exception: + log.exception("error calling listener") try: get_listener_manager().hook.before_stopping(component=TaskRunnerMarker()) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 630aff9094ed1..84894b2bec19b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -4758,6 +4758,69 @@ def test_handle_trigger_dag_run_deferred_with_reset_uses_run_id_only( mock_supervisor_comms.send.assert_any_call(msg) +class TestTaskCheckpointed: + @pytest.mark.parametrize( + "checkpoint_data", + [ + pytest.param(None, id="no-payload"), + pytest.param({"step": 3, "iterator_offset": 1024}, id="dict-payload"), + pytest.param([1, 2, 3], id="list-payload"), + ], + ) + def test_run_returns_checkpointed_state(self, checkpoint_data, create_runtime_ti, mock_supervisor_comms): + """``run()`` reports CHECKPOINTED when the operator raises + ``AirflowTaskCheckpointed``.""" + from airflow.sdk.exceptions import AirflowTaskCheckpointed + + def _raise_checkpointed(): + raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data) + + task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed) + ti = create_runtime_ti(task=task) + + state, _msg, _error = run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + + assert state == TaskInstanceState.CHECKPOINTED + + @pytest.mark.parametrize( + "checkpoint_data", + [ + pytest.param(None, id="no-payload"), + pytest.param({"step": 7}, id="dict-payload"), + ], + ) + def test_listener_receives_checkpoint_data( + self, checkpoint_data, create_runtime_ti, mock_supervisor_comms, listener_manager + ): + """``finalize()`` invokes ``on_task_instance_checkpointed`` and forwards + the operator-supplied ``checkpoint_data``.""" + from airflow.sdk.exceptions import AirflowTaskCheckpointed + + received = {"called": 0, "data": ""} + + class CheckpointListener: + @hookimpl + def on_task_instance_checkpointed(self, previous_state, task_instance, checkpoint_data): + received["called"] += 1 + received["data"] = checkpoint_data + + listener_manager(CheckpointListener()) + + def _raise_checkpointed(): + raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data) + + task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed) + ti = create_runtime_ti(task=task) + log = mock.MagicMock() + context = ti.get_template_context() + state, _msg, _error = run(ti, context=context, log=log) + finalize(ti, state, context, log) + + assert state == TaskInstanceState.CHECKPOINTED + assert received["called"] == 1 + assert received["data"] == checkpoint_data + + class TestTaskInstanceMetrics: def test_ti_start_metric_emitted(self, create_runtime_ti, mock_supervisor_comms): """Test that ti.start metric is emitted at the beginning of task."""