Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions airflow-core/newsfragments/66445.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Wire ``CHECKPOINTED`` (introduced in #66402) through the supervisor → API server path so the DB row transitions to ``checkpointed`` when an operator raises ``AirflowTaskCheckpointed``.

- ``CHECKPOINTED`` is added to ``STATES_SENT_DIRECTLY`` so the supervisor
does not route it through the terminal-state ``finish()`` endpoint.
- The supervisor's ``TaskState`` handler calls a new
``client.task_instances.checkpoint()`` method when the worker reports
``CHECKPOINTED``.
- The new client method PATCHes ``/task-instances/{id}/state`` with
``TITargetStatePayload(state=CHECKPOINTED)`` — the same shape
``DEFERRED`` and ``UP_FOR_RESCHEDULE`` already use.

``checkpoint_data`` persistence is intentionally not added in this PR. The
right shape (XCom-backed, new TaskInstance JSON column, separate metadata
table) is the open AIP-96 question; this PR ships the state-transition
wiring so the DB lands at ``checkpointed`` while the persistence question
is debated.
Original file line number Diff line number Diff line change
Expand Up @@ -3546,6 +3546,7 @@ components:
- upstream_failed
- skipped
- deferred
- checkpointed
title: TaskInstanceState
description: 'All possible states that a Task Instance can be in.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14445,6 +14445,7 @@ components:
- upstream_failed
- skipped
- deferred
- checkpointed
title: TaskInstanceState
description: 'All possible states that a Task Instance can be in.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,13 @@ def _create_ti_state_update_query_and_update_state(
# clear the next_method and next_kwargs so that none of the retries pick them up
updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
query = query.values(state=updated_state, next_method=None, next_kwargs=None)
elif type(ti_patch_payload).__name__ == "TITargetStatePayload":
# Generic target-state transition (e.g. CHECKPOINTED). No additional
# bookkeeping beyond setting the state — persistence of any
# state-specific payload (e.g. checkpoint_data) is the open AIP-96
# question not addressed here.
updated_state = TaskInstanceState(ti_patch_payload.state.value)
query = query.values(state=updated_state)
else:
raise ValueError(f"Unexpected Payload Type {type(ti_patch_payload)}")

Expand Down
5 changes: 5 additions & 0 deletions airflow-core/src/airflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class IntermediateTIState(str, Enum):
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"

def __str__(self) -> str:
return self.value
Expand Down Expand Up @@ -87,6 +88,7 @@ class TaskInstanceState(str, Enum):
UPSTREAM_FAILED = TerminalTIState.UPSTREAM_FAILED # One or more upstream deps failed
SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other mechanism
DEFERRED = IntermediateTIState.DEFERRED # Deferrable operator waiting on a trigger
CHECKPOINTED = IntermediateTIState.CHECKPOINTED # Operator paused at a stable checkpoint

def __str__(self) -> str:
return self.value
Expand Down Expand Up @@ -130,6 +132,7 @@ class State:
UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED
SKIPPED = TaskInstanceState.SKIPPED
DEFERRED = TaskInstanceState.DEFERRED
CHECKPOINTED = TaskInstanceState.CHECKPOINTED

finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED])
unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING])
Expand Down Expand Up @@ -157,6 +160,7 @@ class State:
TaskInstanceState.REMOVED: "lightgrey",
TaskInstanceState.SCHEDULED: "tan",
TaskInstanceState.DEFERRED: "mediumpurple",
TaskInstanceState.CHECKPOINTED: "lightyellow",
}

@classmethod
Expand Down Expand Up @@ -200,6 +204,7 @@ def color_fg(cls, state):
TaskInstanceState.UP_FOR_RETRY,
TaskInstanceState.UP_FOR_RESCHEDULE,
TaskInstanceState.DEFERRED,
TaskInstanceState.CHECKPOINTED,
]
)
"""
Expand Down
1 change: 1 addition & 0 deletions airflow-ctl/src/airflowctl/api/datamodels/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ class TaskInstanceState(str, Enum):
UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"


class TaskInstancesBatchBody(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ components:
- upstream_failed
- skipped
- deferred
- checkpointed
title: TaskInstanceState
description: 'All possible states that a Task Instance can be in.
Expand Down
16 changes: 16 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
# Create a reschedule state payload from msg
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def checkpoint(
self,
id: uuid.UUID,
end_date: datetime | None = None,
rendered_map_index: str | None = None,
) -> None:
"""Tell the API server that this TI has reached a stable checkpoint and paused."""
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TITargetStatePayload

del (
end_date,
rendered_map_index,
) # reserved for future persistence; current API server does not store these
body = TITargetStatePayload(state=IntermediateTIState.CHECKPOINTED)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class IntermediateTIState(str, Enum):
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"


class PrevSuccessfulDagRunResponse(BaseModel):
Expand Down Expand Up @@ -365,6 +366,7 @@ class TaskInstanceState(str, Enum):
UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"


class TaskStatePutBody(BaseModel):
Expand Down
27 changes: 27 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,33 @@ class AirflowSkipException(AirflowException):
"""Raise when the task should be skipped."""


class AirflowTaskCheckpointed(AirflowException):
"""
Raise when the operator has reached a stable checkpoint and intends to pause.

The worker reports the task as CHECKPOINTED and persists the optional
``checkpoint_data`` so a subsequent run can resume from the same point.

This is the AIP-96 (Resumable Operators) foundation primitive. Auto-resume
semantics, scheduler treatment, and downstream trigger-rule integration are
intentionally deferred to follow-ups so the API shape can be discussed
without committing to a single resumption policy.

:param checkpoint_data: Arbitrary serializable payload representing the
operator's resume point. Persistence and resume wiring are out of scope
for this foundation; the parameter exists so listeners and operator
authors can iterate against the final shape.
"""

def __init__(self, checkpoint_data=None):
super().__init__()
self.checkpoint_data = checkpoint_data

def serialize(self):
cls = self.__class__
return f"{cls.__module__}.{cls.__name__}", (), {"checkpoint_data": self.checkpoint_data}


class AirflowTaskTerminated(BaseException):
"""Raise when the task execution is terminated."""

Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ class TaskState(BaseModel):
TaskInstanceState.FAILED,
TaskInstanceState.SKIPPED,
TaskInstanceState.REMOVED,
TaskInstanceState.CHECKPOINTED,
]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
TaskInstanceState.UP_FOR_RESCHEDULE,
TaskInstanceState.UP_FOR_RETRY,
TaskInstanceState.SUCCESS,
TaskInstanceState.CHECKPOINTED,
SERVER_TERMINATED,
]

Expand Down Expand Up @@ -1397,6 +1398,12 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
if msg.state == TaskInstanceState.CHECKPOINTED:
self.client.task_instances.checkpoint(
id=self.id,
end_date=msg.end_date,
rendered_map_index=self._rendered_map_index,
)
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
Expand Down
10 changes: 10 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ def run(
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskCheckpointed,
AirflowTaskTerminated,
DagRunTriggerException,
DownstreamTasksSkipped,
Expand Down Expand Up @@ -1342,6 +1343,15 @@ def _on_term(signum, frame):
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.SKIPPED
except AirflowTaskCheckpointed as checkpoint:
log.info("Task checkpointed; reporting CHECKPOINTED state.")
msg = TaskState(
state=TaskInstanceState.CHECKPOINTED,
end_date=datetime.now(tz=timezone.utc),
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.CHECKPOINTED
ti._checkpoint_data = checkpoint.checkpoint_data # type: ignore[attr-defined]
except AirflowRescheduleException as reschedule:
log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
msg = RescheduleTask(
Expand Down
13 changes: 13 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,19 @@ def handle_request(request: httpx.Request) -> httpx.Response:
)
client.task_instances.reschedule(ti_id, msg)

def test_task_instance_checkpoint(self):
ti_id = uuid6.uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/state":
actual_body = json.loads(request.read())
assert actual_body["state"] == "checkpointed"
return httpx.Response(status_code=204)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.checkpoint(ti_id)

def test_task_instance_up_for_retry(self):
ti_id = uuid6.uuid7()

Expand Down
29 changes: 29 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4758,6 +4758,35 @@ def test_handle_trigger_dag_run_deferred_with_reset_uses_run_id_only(
mock_supervisor_comms.send.assert_any_call(msg)


class TestTaskCheckpointed:
"""AIP-96 foundation: AirflowTaskCheckpointed -> CHECKPOINTED state."""

@pytest.mark.parametrize(
"checkpoint_data",
[
pytest.param(None, id="no-payload"),
pytest.param({"step": 3, "iterator_offset": 1024}, id="dict-payload"),
pytest.param([1, 2, 3], id="list-payload"),
],
)
def test_run_returns_checkpointed_state(self, checkpoint_data, create_runtime_ti, mock_supervisor_comms):
"""``run()`` reports CHECKPOINTED when the operator raises
``AirflowTaskCheckpointed``. The exception's ``checkpoint_data`` payload
is preserved on the exception object regardless of shape; persistence
and resume wiring are out of scope for the foundation PR."""
from airflow.sdk.exceptions import AirflowTaskCheckpointed

def _raise_checkpointed():
raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data)

task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed)
ti = create_runtime_ti(task=task)

state, _msg, _error = run(ti, context=ti.get_template_context(), log=mock.MagicMock())

assert state == TaskInstanceState.CHECKPOINTED


class TestTaskInstanceMetrics:
def test_ti_start_metric_emitted(self, create_runtime_ti, mock_supervisor_comms):
"""Test that ti.start metric is emitted at the beginning of task."""
Expand Down
Loading