diff --git a/airflow-core/newsfragments/66402.significant.rst b/airflow-core/newsfragments/66402.significant.rst new file mode 100644 index 0000000000000..7829297e3bedb --- /dev/null +++ b/airflow-core/newsfragments/66402.significant.rst @@ -0,0 +1,24 @@ +Add the ``CHECKPOINTED`` task instance state and the ``AirflowTaskCheckpointed`` exception (AIP-96 foundation). + +An operator can raise ``AirflowTaskCheckpointed(checkpoint_data=...)`` from +``execute()`` to signal "I have reached a stable checkpoint and intend to +pause". The worker catches the exception and reports the ``CHECKPOINTED`` +state. ``CHECKPOINTED`` is an intermediate state and is included in +``State.unfinished``. + +.. code-block:: python + + from airflow.sdk.exceptions import AirflowTaskCheckpointed + + class ResumablePythonOperator(PythonOperator): + def execute(self, context): + for step in range(self.total_steps): + if self.should_pause(): + raise AirflowTaskCheckpointed(checkpoint_data={"step": step}) + self.do_work(step) + +This change intentionally ships only the vocabulary plus the worker-side +state transition. ``checkpoint_data`` persistence, scheduler auto-resume +semantics, the listener hook, and downstream trigger-rule integration are +deferred to follow-up PRs so the API shape can be discussed against a real +artifact before committing to a single resumption policy. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 2983263bbc59b..bdf69e69be527 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -3546,6 +3546,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. 
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 5ed96855c24ac..64a120d893c37 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -14445,6 +14445,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/airflow-core/src/airflow/utils/state.py b/airflow-core/src/airflow/utils/state.py index 332efb105533d..890043f094656 100644 --- a/airflow-core/src/airflow/utils/state.py +++ b/airflow-core/src/airflow/utils/state.py @@ -56,6 +56,7 @@ class IntermediateTIState(str, Enum): UP_FOR_RETRY = "up_for_retry" UP_FOR_RESCHEDULE = "up_for_reschedule" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" def __str__(self) -> str: return self.value @@ -87,6 +88,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = TerminalTIState.UPSTREAM_FAILED # One or more upstream deps failed SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other mechanism DEFERRED = IntermediateTIState.DEFERRED # Deferrable operator waiting on a trigger + CHECKPOINTED = IntermediateTIState.CHECKPOINTED # Operator paused at a stable checkpoint def __str__(self) -> str: return self.value @@ -130,6 +132,7 @@ class State: UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED SKIPPED = TaskInstanceState.SKIPPED DEFERRED = TaskInstanceState.DEFERRED + CHECKPOINTED = TaskInstanceState.CHECKPOINTED finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED]) unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING]) @@ -157,6 +160,7 @@ class State: TaskInstanceState.REMOVED: "lightgrey", TaskInstanceState.SCHEDULED: "tan", 
TaskInstanceState.DEFERRED: "mediumpurple", + TaskInstanceState.CHECKPOINTED: "lightyellow", } @classmethod @@ -200,6 +204,7 @@ def color_fg(cls, state): TaskInstanceState.UP_FOR_RETRY, TaskInstanceState.UP_FOR_RESCHEDULE, TaskInstanceState.DEFERRED, + TaskInstanceState.CHECKPOINTED, ] ) """ diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index 276c8699de058..d549d78711481 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -841,6 +841,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = "upstream_failed" SKIPPED = "skipped" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class TaskInstancesBatchBody(BaseModel): diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 01c8149d1dad8..e894eda0837dd 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1274,6 +1274,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. 
class AirflowTaskCheckpointed(AirflowException):
    """
    Raise when the operator has reached a stable checkpoint and intends to pause.

    The worker catches this exception and reports the task as CHECKPOINTED.
    The optional ``checkpoint_data`` is only carried on the exception object
    (and included in :meth:`serialize`); it is **not** persisted yet, so a
    subsequent run cannot rely on it being available.

    This is the AIP-96 (Resumable Operators) foundation primitive. Auto-resume
    semantics, scheduler treatment, and downstream trigger-rule integration are
    intentionally deferred to follow-ups so the API shape can be discussed
    without committing to a single resumption policy.

    :param checkpoint_data: Arbitrary serializable payload representing the
        operator's resume point. Persistence and resume wiring are out of scope
        for this foundation; the parameter exists so listeners and operator
        authors can iterate against the final shape.
    """

    def __init__(self, checkpoint_data=None):
        # No message is passed to the base class: the payload, not a
        # human-readable string, is the information this exception carries.
        super().__init__()
        self.checkpoint_data = checkpoint_data

    def serialize(self):
        """
        Return the pieces needed to rebuild this exception.

        :return: A 3-tuple of the dotted import path of the concrete class,
            the positional arguments (always empty), and the keyword
            arguments (``checkpoint_data`` only).
        """
        # Use the runtime class so subclasses serialize under their own path.
        cls = self.__class__
        return f"{cls.__module__}.{cls.__name__}", (), {"checkpoint_data": self.checkpoint_data}
class TestTaskCheckpointed:
    """AIP-96 foundation: AirflowTaskCheckpointed -> CHECKPOINTED state."""

    @pytest.mark.parametrize(
        "checkpoint_data",
        [
            pytest.param(None, id="no-payload"),
            pytest.param({"step": 3, "iterator_offset": 1024}, id="dict-payload"),
            pytest.param([1, 2, 3], id="list-payload"),
        ],
    )
    def test_run_returns_checkpointed_state(self, checkpoint_data, create_runtime_ti, mock_supervisor_comms):
        """``run()`` reports CHECKPOINTED when the operator raises
        ``AirflowTaskCheckpointed``.

        Besides the returned state, the message handed back by ``run()`` must
        carry the CHECKPOINTED state, and the exception's ``checkpoint_data``
        payload must be kept on the runtime task instance regardless of its
        shape. Persistence and resume wiring are out of scope for the
        foundation PR, so nothing beyond the in-process hand-off is checked.
        """
        from airflow.sdk.exceptions import AirflowTaskCheckpointed

        def _raise_checkpointed():
            raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data)

        task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed)
        ti = create_runtime_ti(task=task)

        state, msg, _error = run(ti, context=ti.get_template_context(), log=mock.MagicMock())

        assert state == TaskInstanceState.CHECKPOINTED
        # The message sent to the supervisor must agree with the returned state.
        assert msg.state == TaskInstanceState.CHECKPOINTED
        # Without this check the parametrization over payload shapes verifies nothing.
        assert ti._checkpoint_data == checkpoint_data