diff --git a/airflow-core/newsfragments/66445.significant.rst b/airflow-core/newsfragments/66445.significant.rst new file mode 100644 index 0000000000000..6eeafa360664e --- /dev/null +++ b/airflow-core/newsfragments/66445.significant.rst @@ -0,0 +1,16 @@ +Wire ``CHECKPOINTED`` (introduced in #66402) through the supervisor → API server path so the DB row transitions to ``checkpointed`` when an operator raises ``AirflowTaskCheckpointed``. + +- ``CHECKPOINTED`` is added to ``STATES_SENT_DIRECTLY`` so the supervisor + does not route it through the terminal-state ``finish()`` endpoint. +- The supervisor's ``TaskState`` handler calls a new + ``client.task_instances.checkpoint()`` method when the worker reports + ``CHECKPOINTED``. +- The new client method PATCHes ``/task-instances/{id}/state`` with + ``TITargetStatePayload(state=CHECKPOINTED)`` — the same shape + ``DEFERRED`` and ``UP_FOR_RESCHEDULE`` already use. + +``checkpoint_data`` persistence is intentionally not added in this PR. The +right shape (XCom-backed, new TaskInstance JSON column, separate metadata +table) is the open AIP-96 question; this PR ships the state-transition +wiring so the DB lands at ``checkpointed`` while the persistence question +is debated. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 2983263bbc59b..bdf69e69be527 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -3546,6 +3546,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 5ed96855c24ac..64a120d893c37 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -14445,6 +14445,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 0fccd6ad1b61c..bb96e17b0eaa5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -635,6 +635,13 @@ def _create_ti_state_update_query_and_update_state( # clear the next_method and next_kwargs so that none of the retries pick them up updated_state = TaskInstanceState.UP_FOR_RESCHEDULE query = query.values(state=updated_state, next_method=None, next_kwargs=None) + elif type(ti_patch_payload).__name__ == "TITargetStatePayload": + # Generic target-state transition (e.g. CHECKPOINTED). No additional + # bookkeeping beyond setting the state — persistence of any + # state-specific payload (e.g. checkpoint_data) is the open AIP-96 + # question not addressed here. + updated_state = TaskInstanceState(ti_patch_payload.state.value) + query = query.values(state=updated_state) else: raise ValueError(f"Unexpected Payload Type {type(ti_patch_payload)}") diff --git a/airflow-core/src/airflow/utils/state.py b/airflow-core/src/airflow/utils/state.py index 332efb105533d..890043f094656 100644 --- a/airflow-core/src/airflow/utils/state.py +++ b/airflow-core/src/airflow/utils/state.py @@ -56,6 +56,7 @@ class IntermediateTIState(str, Enum): UP_FOR_RETRY = "up_for_retry" UP_FOR_RESCHEDULE = "up_for_reschedule" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" def __str__(self) -> str: return self.value @@ -87,6 +88,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = TerminalTIState.UPSTREAM_FAILED # One or more upstream deps failed SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other mechanism DEFERRED = IntermediateTIState.DEFERRED # Deferrable operator waiting on a trigger + CHECKPOINTED = IntermediateTIState.CHECKPOINTED # Operator paused at a stable checkpoint def __str__(self) -> str: return self.value @@ -130,6 +132,7 @@ class State: UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED SKIPPED = TaskInstanceState.SKIPPED DEFERRED = TaskInstanceState.DEFERRED + CHECKPOINTED = TaskInstanceState.CHECKPOINTED finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED]) unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING]) @@ -157,6 +160,7 @@ class State: TaskInstanceState.REMOVED: "lightgrey", TaskInstanceState.SCHEDULED: "tan", TaskInstanceState.DEFERRED: "mediumpurple", + TaskInstanceState.CHECKPOINTED: "lightyellow", } @classmethod @@ -200,6 +204,7 @@ def color_fg(cls, state): TaskInstanceState.UP_FOR_RETRY, TaskInstanceState.UP_FOR_RESCHEDULE, TaskInstanceState.DEFERRED, + TaskInstanceState.CHECKPOINTED, ] ) """ diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index 276c8699de058..d549d78711481 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -841,6 +841,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = "upstream_failed" SKIPPED = "skipped" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class TaskInstancesBatchBody(BaseModel): diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 01c8149d1dad8..e894eda0837dd 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1274,6 +1274,7 @@ components: - upstream_failed - skipped - deferred + - checkpointed title: TaskInstanceState description: 'All possible states that a Task Instance can be in. diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 54927794bf17e..e4c8f61e0b779 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -281,6 +281,22 @@ def reschedule(self, id: uuid.UUID, msg: RescheduleTask): # Create a reschedule state payload from msg self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) + def checkpoint( + self, + id: uuid.UUID, + end_date: datetime | None = None, + rendered_map_index: str | None = None, + ) -> None: + """Tell the API server that this TI has reached a stable checkpoint and paused.""" + from airflow.sdk.api.datamodels._generated import IntermediateTIState, TITargetStatePayload + + del ( + end_date, + rendered_map_index, + ) # reserved for future persistence; current API server does not store these + body = TITargetStatePayload(state=IntermediateTIState.CHECKPOINTED) + self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) + def heartbeat(self, id: uuid.UUID, pid: int): body = TIHeartbeatInfo(pid=pid, hostname=get_hostname()) self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json()) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b5b100154c389..306cd386c38ac 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -198,6 +198,7 @@ class IntermediateTIState(str, Enum): UP_FOR_RETRY = "up_for_retry" UP_FOR_RESCHEDULE = "up_for_reschedule" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class PrevSuccessfulDagRunResponse(BaseModel): @@ -365,6 +366,7 @@ class TaskInstanceState(str, Enum): UPSTREAM_FAILED = "upstream_failed" SKIPPED = "skipped" DEFERRED = "deferred" + CHECKPOINTED = "checkpointed" class TaskStatePutBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 7d42dad5d8502..cbde73db11224 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -165,6 +165,33 @@ class AirflowSkipException(AirflowException): """Raise when the task should be skipped.""" +class AirflowTaskCheckpointed(AirflowException): + """ + Raise when the operator has reached a stable checkpoint and intends to pause. + + The worker reports the task as CHECKPOINTED and persists the optional + ``checkpoint_data`` so a subsequent run can resume from the same point. + + This is the AIP-96 (Resumable Operators) foundation primitive. Auto-resume + semantics, scheduler treatment, and downstream trigger-rule integration are + intentionally deferred to follow-ups so the API shape can be discussed + without committing to a single resumption policy. + + :param checkpoint_data: Arbitrary serializable payload representing the + operator's resume point. Persistence and resume wiring are out of scope + for this foundation; the parameter exists so listeners and operator + authors can iterate against the final shape. + """ + + def __init__(self, checkpoint_data=None): + super().__init__() + self.checkpoint_data = checkpoint_data + + def serialize(self): + cls = self.__class__ + return f"{cls.__module__}.{cls.__name__}", (), {"checkpoint_data": self.checkpoint_data} + + class AirflowTaskTerminated(BaseException): """Raise when the task execution is terminated.""" diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 87c7881333ad4..ded750db1840c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -754,6 +754,7 @@ class TaskState(BaseModel): TaskInstanceState.FAILED, TaskInstanceState.SKIPPED, TaskInstanceState.REMOVED, + TaskInstanceState.CHECKPOINTED, ] end_date: datetime | None = None type: Literal["TaskState"] = "TaskState" diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 375c5a9e30b8e..5ef521990d2b0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -166,6 +166,7 @@ TaskInstanceState.UP_FOR_RESCHEDULE, TaskInstanceState.UP_FOR_RETRY, TaskInstanceState.SUCCESS, + TaskInstanceState.CHECKPOINTED, SERVER_TERMINATED, ] @@ -1397,6 +1398,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() self._rendered_map_index = msg.rendered_map_index + if msg.state == TaskInstanceState.CHECKPOINTED: + self.client.task_instances.checkpoint( + id=self.id, + end_date=msg.end_date, + rendered_map_index=self._rendered_map_index, + ) elif isinstance(msg, SucceedTask): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 56ba8343c648b..bfbf124611fb5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1248,6 +1248,7 @@ def run( AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, + AirflowTaskCheckpointed, AirflowTaskTerminated, DagRunTriggerException, DownstreamTasksSkipped, @@ -1342,6 +1343,15 @@ def _on_term(signum, frame): rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.SKIPPED + except AirflowTaskCheckpointed as checkpoint: + log.info("Task checkpointed; reporting CHECKPOINTED state.") + msg = TaskState( + state=TaskInstanceState.CHECKPOINTED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) + state = TaskInstanceState.CHECKPOINTED + ti._checkpoint_data = checkpoint.checkpoint_data # type: ignore[attr-defined] except AirflowRescheduleException as reschedule: log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") msg = RescheduleTask( diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 26e2a7e66bb6e..38c94481cdcd1 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -482,6 +482,19 @@ def handle_request(request: httpx.Request) -> httpx.Response: ) client.task_instances.reschedule(ti_id, msg) + def test_task_instance_checkpoint(self): + ti_id = uuid6.uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/task-instances/{ti_id}/state": + actual_body = json.loads(request.read()) + assert actual_body["state"] == "checkpointed" + return httpx.Response(status_code=204) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + client.task_instances.checkpoint(ti_id) + def test_task_instance_up_for_retry(self): ti_id = uuid6.uuid7() diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 630aff9094ed1..4a6a68165fd3c 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -4758,6 +4758,35 @@ def test_handle_trigger_dag_run_deferred_with_reset_uses_run_id_only( mock_supervisor_comms.send.assert_any_call(msg) +class TestTaskCheckpointed: + """AIP-96 foundation: AirflowTaskCheckpointed -> CHECKPOINTED state.""" + + @pytest.mark.parametrize( + "checkpoint_data", + [ + pytest.param(None, id="no-payload"), + pytest.param({"step": 3, "iterator_offset": 1024}, id="dict-payload"), + pytest.param([1, 2, 3], id="list-payload"), + ], + ) + def test_run_returns_checkpointed_state(self, checkpoint_data, create_runtime_ti, mock_supervisor_comms): + """``run()`` reports CHECKPOINTED when the operator raises + ``AirflowTaskCheckpointed``. The exception's ``checkpoint_data`` payload + is preserved on the exception object regardless of shape; persistence + and resume wiring are out of scope for the foundation PR.""" + from airflow.sdk.exceptions import AirflowTaskCheckpointed + + def _raise_checkpointed(): + raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data) + + task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed) + ti = create_runtime_ti(task=task) + + state, _msg, _error = run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + + assert state == TaskInstanceState.CHECKPOINTED + + class TestTaskInstanceMetrics: def test_ti_start_metric_emitted(self, create_runtime_ti, mock_supervisor_comms): """Test that ti.start metric is emitted at the beginning of task."""