Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/66410.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add the ``on_task_instance_checkpointed`` listener hook so that listener authors can observe checkpoint events alongside running, success, failed, and skipped.
Original file line number Diff line number Diff line change
Expand Up @@ -3546,6 +3546,7 @@ components:
- upstream_failed
- skipped
- deferred
- checkpointed
title: TaskInstanceState
description: 'All possible states that a Task Instance can be in.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14445,6 +14445,7 @@ components:
- upstream_failed
- skipped
- deferred
- checkpointed
title: TaskInstanceState
description: 'All possible states that a Task Instance can be in.

Expand Down
20 changes: 20 additions & 0 deletions airflow-core/src/airflow/example_dags/plugins/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,26 @@ def on_task_instance_skipped(
# [END howto_listen_ti_skipped_task]


# [START howto_listen_ti_checkpointed_task]
@hookimpl
def on_task_instance_checkpointed(
    previous_state: TaskInstanceState | None,
    task_instance: RuntimeTaskInstance | TaskInstance,
    checkpoint_data,
):
    """
    Handle a task instance that paused at a stable checkpoint.

    Invoked after the operator raises ``AirflowTaskCheckpointed`` from
    ``execute()`` and the worker reports the ``CHECKPOINTED`` state.

    :param previous_state: State the task instance was in before it checkpointed.
    :param task_instance: The task instance that reached the checkpoint.
    :param checkpoint_data: Payload the operator supplied when raising; may be
        ``None`` when no payload was given.
    """
    # Example listener only reports the event; the payload is shown via repr().
    message = f"Task instance reached a checkpoint (data={checkpoint_data!r})"
    print(message)


# [END howto_listen_ti_checkpointed_task]


# [START howto_listen_dagrun_success_task]
@hookimpl
def on_dag_run_success(dag_run: DagRun, msg: str):
Expand Down
5 changes: 5 additions & 0 deletions airflow-core/src/airflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class IntermediateTIState(str, Enum):
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"

def __str__(self) -> str:
return self.value
Expand Down Expand Up @@ -87,6 +88,7 @@ class TaskInstanceState(str, Enum):
UPSTREAM_FAILED = TerminalTIState.UPSTREAM_FAILED # One or more upstream deps failed
SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other mechanism
DEFERRED = IntermediateTIState.DEFERRED # Deferrable operator waiting on a trigger
CHECKPOINTED = IntermediateTIState.CHECKPOINTED # Operator paused at a stable checkpoint

def __str__(self) -> str:
return self.value
Expand Down Expand Up @@ -130,6 +132,7 @@ class State:
UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED
SKIPPED = TaskInstanceState.SKIPPED
DEFERRED = TaskInstanceState.DEFERRED
CHECKPOINTED = TaskInstanceState.CHECKPOINTED

finished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.SUCCESS, DagRunState.FAILED])
unfinished_dr_states: frozenset[DagRunState] = frozenset([DagRunState.QUEUED, DagRunState.RUNNING])
Expand Down Expand Up @@ -157,6 +160,7 @@ class State:
TaskInstanceState.REMOVED: "lightgrey",
TaskInstanceState.SCHEDULED: "tan",
TaskInstanceState.DEFERRED: "mediumpurple",
TaskInstanceState.CHECKPOINTED: "lightyellow",
}

@classmethod
Expand Down Expand Up @@ -200,6 +204,7 @@ def color_fg(cls, state):
TaskInstanceState.UP_FOR_RETRY,
TaskInstanceState.UP_FOR_RESCHEDULE,
TaskInstanceState.DEFERRED,
TaskInstanceState.CHECKPOINTED,
]
)
"""
Expand Down
1 change: 1 addition & 0 deletions airflow-ctl/src/airflowctl/api/datamodels/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ class TaskInstanceState(str, Enum):
UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"


class TaskInstancesBatchBody(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ components:
- upstream_failed
- skipped
- deferred
- checkpointed
title: TaskInstanceState
description: 'All possible states that a Task Instance can be in.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,25 @@ def on_task_instance_skipped(
:param task_instance: The task instance object (RuntimeTaskInstance when called
from task execution context, TaskInstance when called from API server)
"""


@hookspec
def on_task_instance_checkpointed(
    previous_state: TaskInstanceState | None,
    task_instance: RuntimeTaskInstance | TaskInstance,
    checkpoint_data,
):
    """
    Execute when a task instance reaches a stable checkpoint and pauses.

    Fires when an operator raises ``AirflowTaskCheckpointed`` and the worker
    reports ``CHECKPOINTED`` state. This function is a hook specification only:
    its body is intentionally empty and listeners provide the behaviour through
    a matching ``@hookimpl``.

    :param previous_state: Previous state of the task instance (typically
        RUNNING, which is what the task runner passes when it fires this hook)
    :param task_instance: The task instance object (RuntimeTaskInstance when
        called from task execution context, TaskInstance when called from API
        server)
    :param checkpoint_data: The serializable payload the operator passed to
        ``AirflowTaskCheckpointed(checkpoint_data=...)``, or ``None`` if no
        payload was supplied.
    """
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class IntermediateTIState(str, Enum):
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"


class PrevSuccessfulDagRunResponse(BaseModel):
Expand Down Expand Up @@ -365,6 +366,7 @@ class TaskInstanceState(str, Enum):
UPSTREAM_FAILED = "upstream_failed"
SKIPPED = "skipped"
DEFERRED = "deferred"
CHECKPOINTED = "checkpointed"


class TaskStatePutBody(BaseModel):
Expand Down
27 changes: 27 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,33 @@ class AirflowSkipException(AirflowException):
"""Raise when the task should be skipped."""


class AirflowTaskCheckpointed(AirflowException):
    """
    Signal that an operator reached a stable checkpoint and intends to pause.

    Operators raise this from ``execute()``. The worker then reports the task
    as CHECKPOINTED and makes the optional ``checkpoint_data`` available to
    ``on_task_instance_checkpointed`` listeners.

    This is the AIP-96 (Resumable Operators) foundation primitive. Persisting
    the payload, auto-resume semantics, scheduler treatment, and downstream
    trigger-rule integration are intentionally deferred to follow-ups so the
    API shape can be discussed without committing to a single resumption
    policy.

    :param checkpoint_data: Arbitrary serializable payload representing the
        operator's resume point. It is carried on the exception so listeners
        and operator authors can iterate against the final shape; nothing in
        this foundation stores it for a later run.
    """

    def __init__(self, checkpoint_data=None):
        # No message is passed to the base exception; the payload is the only state.
        super().__init__()
        self.checkpoint_data = checkpoint_data

    def serialize(self):
        """
        Return a ``(dotted class path, positional args, keyword args)`` triple.

        The positional part is always empty; the keyword part carries
        ``checkpoint_data`` under the same name the constructor accepts.
        """
        exc_type = type(self)
        import_path = f"{exc_type.__module__}.{exc_type.__name__}"
        ctor_kwargs = {"checkpoint_data": self.checkpoint_data}
        return import_path, (), ctor_kwargs


class AirflowTaskTerminated(BaseException):
"""Raise when the task execution is terminated."""

Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ class TaskState(BaseModel):
TaskInstanceState.FAILED,
TaskInstanceState.SKIPPED,
TaskInstanceState.REMOVED,
TaskInstanceState.CHECKPOINTED,
]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"
Expand Down
19 changes: 19 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ def run(
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskCheckpointed,
AirflowTaskTerminated,
DagRunTriggerException,
DownstreamTasksSkipped,
Expand Down Expand Up @@ -1342,6 +1343,15 @@ def _on_term(signum, frame):
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.SKIPPED
except AirflowTaskCheckpointed as checkpoint:
log.info("Task checkpointed; reporting CHECKPOINTED state.")
msg = TaskState(
state=TaskInstanceState.CHECKPOINTED,
end_date=datetime.now(tz=timezone.utc),
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.CHECKPOINTED
ti._checkpoint_data = checkpoint.checkpoint_data # type: ignore[attr-defined]
except AirflowRescheduleException as reschedule:
log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
msg = RescheduleTask(
Expand Down Expand Up @@ -1932,6 +1942,15 @@ def finalize(
log.exception("error calling listener")
if error and task.email_on_failure and task.email:
_send_error_email_notification(task, ti, context, error, log)
elif state == TaskInstanceState.CHECKPOINTED:
try:
get_listener_manager().hook.on_task_instance_checkpointed(
previous_state=TaskInstanceState.RUNNING,
task_instance=ti,
checkpoint_data=getattr(ti, "_checkpoint_data", None),
)
except Exception:
log.exception("error calling listener")

try:
get_listener_manager().hook.before_stopping(component=TaskRunnerMarker())
Expand Down
63 changes: 63 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4758,6 +4758,69 @@ def test_handle_trigger_dag_run_deferred_with_reset_uses_run_id_only(
mock_supervisor_comms.send.assert_any_call(msg)


class TestTaskCheckpointed:
    """Cover the ``AirflowTaskCheckpointed`` path through ``run()`` and ``finalize()``."""

    @pytest.mark.parametrize(
        "checkpoint_data",
        [
            pytest.param(None, id="no-payload"),
            pytest.param({"step": 3, "iterator_offset": 1024}, id="dict-payload"),
            pytest.param([1, 2, 3], id="list-payload"),
        ],
    )
    def test_run_returns_checkpointed_state(self, checkpoint_data, create_runtime_ti, mock_supervisor_comms):
        """``run()`` reports CHECKPOINTED when the operator raises
        ``AirflowTaskCheckpointed`` and keeps the operator-supplied payload."""
        from airflow.sdk.exceptions import AirflowTaskCheckpointed

        def _raise_checkpointed():
            raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data)

        task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed)
        ti = create_runtime_ti(task=task)

        state, msg, _error = run(ti, context=ti.get_template_context(), log=mock.MagicMock())

        assert state == TaskInstanceState.CHECKPOINTED
        # The message built for the supervisor must carry the same state.
        assert msg.state == TaskInstanceState.CHECKPOINTED
        # The test is parametrized over payloads, so verify the payload was
        # actually captured; the sentinel default distinguishes "never set"
        # from a legitimate ``None`` payload.
        assert getattr(ti, "_checkpoint_data", "<unset>") == checkpoint_data

    @pytest.mark.parametrize(
        "checkpoint_data",
        [
            pytest.param(None, id="no-payload"),
            pytest.param({"step": 7}, id="dict-payload"),
        ],
    )
    def test_listener_receives_checkpoint_data(
        self, checkpoint_data, create_runtime_ti, mock_supervisor_comms, listener_manager
    ):
        """``finalize()`` invokes ``on_task_instance_checkpointed`` exactly once
        and forwards the previous state, the task instance, and the
        operator-supplied ``checkpoint_data``."""
        from airflow.sdk.exceptions import AirflowTaskCheckpointed

        # Sentinels make a missing call distinguishable from a ``None`` payload.
        received = {
            "called": 0,
            "data": "<unset>",
            "previous_state": "<unset>",
            "task_instance": None,
        }

        class CheckpointListener:
            @hookimpl
            def on_task_instance_checkpointed(self, previous_state, task_instance, checkpoint_data):
                received["called"] += 1
                received["data"] = checkpoint_data
                received["previous_state"] = previous_state
                received["task_instance"] = task_instance

        listener_manager(CheckpointListener())

        def _raise_checkpointed():
            raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data)

        task = PythonOperator(task_id="checkpointed_task", python_callable=_raise_checkpointed)
        ti = create_runtime_ti(task=task)
        log = mock.MagicMock()
        context = ti.get_template_context()
        state, _msg, _error = run(ti, context=context, log=log)
        finalize(ti, state, context, log)

        assert state == TaskInstanceState.CHECKPOINTED
        assert received["called"] == 1
        assert received["data"] == checkpoint_data
        # The task runner reports RUNNING as the state the task left.
        assert received["previous_state"] == TaskInstanceState.RUNNING
        assert received["task_instance"] is ti


class TestTaskInstanceMetrics:
def test_ti_start_metric_emitted(self, create_runtime_ti, mock_supervisor_comms):
"""Test that ti.start metric is emitted at the beginning of task."""
Expand Down
Loading