@@ -0,0 +1,169 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
AIP-96 demonstration — Databricks resumable operator.

A subclass of ``DatabricksSubmitRunOperator`` that survives worker
disruption (pod eviction, OOM kill, node drain) by:

1. Persisting the Databricks ``run_id`` via AIP-103 ``task_state``
immediately after submit.
2. Converting SIGTERM into ``AirflowTaskCheckpointed`` so the worker
transitions to ``CHECKPOINTED`` instead of running ``on_kill``'s
default ``cancel_run`` path.
3. On the next attempt, reading the prior ``run_id`` from
``task_state`` and reconnecting (skipping submit).

This is the v1 integration pattern from AIP-96 v2: roughly 8 lines of
wrapper around the existing operator's submit + poll. The pattern shows
what a real provider integration looks like — not a synthetic example.

NOTE: this file lives under ``tests/system/databricks/`` to match how
other Databricks operator examples are organized in the repo. It is
illustrative and stacks on top of the AIP-96 PR set (#66402, #66410,
#66445); not for merge in this form. Once AIP-96 is accepted, the
upstream-eligible shape would be either a new operator class shipped
alongside ``DatabricksSubmitRunOperator`` in
``providers/databricks/.../operators/databricks.py``, or a
``resumable=True`` flag merged into the existing operator.
"""

from __future__ import annotations

import datetime
import signal
from typing import TYPE_CHECKING

from airflow.providers.databricks.operators.databricks import (
DatabricksSubmitRunOperator,
_handle_databricks_operator_execution,
_handle_deferrable_databricks_operator_execution,
)
from airflow.providers.databricks.utils.databricks import normalise_json_content
from airflow.sdk import DAG
from airflow.sdk.exceptions import AirflowTaskCheckpointed

if TYPE_CHECKING:
from airflow.sdk import Context


class ResumableDatabricksSubmitRunOperator(DatabricksSubmitRunOperator):
"""
Databricks submit-run operator that survives worker disruption.

Differences from the parent operator:

- On the first attempt, persists ``self.run_id`` to AIP-103
``task_state`` immediately after submit.
- On the next attempt (after CHECKPOINTED), reads the prior
``run_id`` from ``task_state`` and skips ``submit_run``,
reconnecting to the existing Databricks job.
- Installs a SIGTERM handler during execute that raises
``AirflowTaskCheckpointed`` instead of letting the default
``on_kill`` cancel the Databricks run.
    - Overrides ``on_kill`` to preserve the run (log only) instead of the
      parent's default ``cancel_run``, acting as a safety net for kill
      paths that bypass the SIGTERM handler.
"""

RESUME_KEY = "databricks_run_id"

def execute(self, context: Context):
# AIP-103: read prior run_id, if any.
prior_run_id = context["task_state"].get(self.RESUME_KEY)

if prior_run_id is not None:
self.log.info("Resuming Databricks run_id=%s from prior attempt", prior_run_id)
self.run_id = prior_run_id
else:
# First attempt: submit and persist.
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
and self.json["pipeline_task"].get("pipeline_name")
):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
json_normalised = normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
context["task_state"].set(self.RESUME_KEY, self.run_id)
self.log.info("Submitted Databricks run_id=%s and persisted to task_state", self.run_id)

        # Install a SIGTERM handler that signals checkpoint instead of
        # cancel-on-kill. A future framework-level helper could install
        # this automatically — see AIP-96 v2 reviewer questions and the
        # illustrative ``_checkpoint_on_sigterm`` sketch after this class.
def _on_sigterm(signum, frame):
raise AirflowTaskCheckpointed(checkpoint_data={"run_id": self.run_id})

prior_handler = signal.signal(signal.SIGTERM, _on_sigterm)
try:
if self.deferrable:
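                # When ``wait_for_termination`` is left enabled, this helper
                # defers (raises TaskDeferred), so the cleanup below never runs
                # in this process; a complete implementation would also clear
                # the resume key in ``execute_complete``.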
_handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
else:
_handle_databricks_operator_execution(self, self._hook, self.log, context)
finally:
signal.signal(signal.SIGTERM, prior_handler)

# Clear the resume key on success so a future DAG run starts fresh.
context["task_state"].delete(self.RESUME_KEY)

def on_kill(self):
# Default DatabricksSubmitRunOperator.on_kill cancels the run.
# In the resumable variant we PRESERVE the run so the next
# attempt can reconnect via task_state. The SIGTERM handler in
# execute() raises AirflowTaskCheckpointed before this fires;
# this method is a safety net for non-SIGTERM kill paths.
if self.run_id:
self.log.info(
"Task with run_id=%s killed; preserving the Databricks run "
"for next-attempt reconnection (run_id stored in task_state).",
self.run_id,
)
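

# ---------------------------------------------------------------------------
# Hypothetical framework-level helper (illustrative sketch only; not part of
# AIP-96 or the Airflow SDK today). It factors out the SIGTERM-to-checkpoint
# wiring used in ``execute`` above so providers would not have to repeat it.
# The name and signature here are assumptions for discussion.
# ---------------------------------------------------------------------------
from contextlib import contextmanager  # noqa: E402


@contextmanager
def _checkpoint_on_sigterm(checkpoint_data_factory):
    """Raise ``AirflowTaskCheckpointed`` on SIGTERM while the block is active."""

    def _handler(signum, frame):
        raise AirflowTaskCheckpointed(checkpoint_data=checkpoint_data_factory())

    prior_handler = signal.signal(signal.SIGTERM, _handler)
    try:
        yield
    finally:
        signal.signal(signal.SIGTERM, prior_handler)


# With such a helper, the poll section of ``execute`` above could read:
#
#     with _checkpoint_on_sigterm(lambda: {"run_id": self.run_id}):
#         _handle_databricks_operator_execution(self, self._hook, self.log, context)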


# ---------------------------------------------------------------------------
# Example DAG using the resumable operator. The notebook task path and
# cluster spec are illustrative; replace with real values to run against
# a Databricks workspace.
# ---------------------------------------------------------------------------

with DAG(
dag_id="example_resumable_databricks",
description="AIP-96 demo: Databricks job survives worker disruption mid-execution",
schedule=None,
start_date=datetime.datetime(2026, 5, 1),
catchup=False,
tags=["example", "aip-96", "resumable", "databricks"],
) as dag:
new_cluster_spec = {
"spark_version": "13.3.x-scala2.12",
"node_type_id": "i3.xlarge",
"num_workers": 2,
}

notebook_task = {
"notebook_path": "/Users/example@example.com/aip96_long_running_demo",
}

resumable_databricks_run = ResumableDatabricksSubmitRunOperator(
task_id="resumable_databricks_run",
databricks_conn_id="databricks_default",
new_cluster=new_cluster_spec,
notebook_task=notebook_task,
retries=2,
retry_delay=datetime.timedelta(seconds=30),
)
@@ -0,0 +1,187 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Layer 2 e2e — framework-level test of the AIP-96 resumable pattern.

Provider-agnostic. Uses an in-process simulator (no Databricks, no
Kubernetes, no real external job) to exercise the AIP-96 + AIP-103
primitives end-to-end:

- AirflowTaskCheckpointed from airflow.sdk.exceptions
- task_state.set/get/delete (AIP-103) — fake backend used here

Shows that the resume contract composes correctly without per-provider
plumbing. The same shape works for any operator that has a submit-then-poll
external-job structure (Databricks, EMR, Spark-on-K8s, Beam, Dataproc, etc.).

Co-located with the Databricks demo because the discussion lives there;
in a real upstream PR this test would live next to the AIP-96 supervisor
wiring code (#66445) so the framework primitives are tested without
provider deps.
"""

from __future__ import annotations

import secrets
from typing import Any

import pytest

from airflow.sdk import BaseOperator
from airflow.sdk.exceptions import AirflowTaskCheckpointed


class _FakeTaskState:
"""In-memory task_state — stand-in for AIP-103's context['task_state']."""

def __init__(self):
self._d: dict[str, Any] = {}

def get(self, key):
return self._d.get(key)

def set(self, key, value):
self._d[key] = value

def delete(self, key):
self._d.pop(key, None)


class _SimulatedResumableOp(BaseOperator):
"""
Minimal resumable operator using an in-process external job simulator.

Mirrors the structure of ResumableDatabricksSubmitRunOperator without
depending on the Databricks provider. The 'external job' is just
counter state on the operator instance.

Test hooks:
- ``_disrupt_at_step``: if set, raises AirflowTaskCheckpointed when
the poll loop reaches that step (simulates worker SIGTERM mid-poll).
"""

RESUME_KEY = "external_id"

def __init__(self, *, total_steps: int = 3, **kwargs):
super().__init__(**kwargs)
self.total_steps = total_steps
self.submit_calls: list[str] = []
self.poll_calls: list[tuple[str, int]] = []
self._disrupt_at_step: int | None = None

def execute(self, context):
external_id = context["task_state"].get(self.RESUME_KEY)
if external_id is None:
external_id = self._submit()
context["task_state"].set(self.RESUME_KEY, external_id)

for step in range(self.total_steps):
self.poll_calls.append((external_id, step))
if step == self._disrupt_at_step:
raise AirflowTaskCheckpointed(checkpoint_data={"external_id": external_id})

context["task_state"].delete(self.RESUME_KEY)
return external_id

def _submit(self) -> str:
external_id = f"job-{secrets.token_hex(2)}"
self.submit_calls.append(external_id)
return external_id


@pytest.fixture
def ctx():
return {"task_state": _FakeTaskState()}


def test_first_attempt_submits_persists_then_clears_on_success(ctx):
op = _SimulatedResumableOp(task_id="t")

result = op.execute(ctx)

assert len(op.submit_calls) == 1
assert result == op.submit_calls[0]
# Polled all steps.
assert len(op.poll_calls) == op.total_steps
# Cleared on success.
assert ctx["task_state"].get("external_id") is None


def test_disruption_raises_checkpointed_and_persists_external_id(ctx):
op = _SimulatedResumableOp(task_id="t")
op._disrupt_at_step = 1

with pytest.raises(AirflowTaskCheckpointed) as exc_info:
op.execute(ctx)

persisted = ctx["task_state"].get("external_id")
assert persisted is not None
assert persisted == op.submit_calls[0]
assert exc_info.value.checkpoint_data == {"external_id": persisted}
# Polled up to and including the disruption step.
assert len(op.poll_calls) == op._disrupt_at_step + 1


def test_resume_after_checkpoint_skips_submit_and_completes(ctx):
op = _SimulatedResumableOp(task_id="t")

# First attempt — disrupted.
op._disrupt_at_step = 1
with pytest.raises(AirflowTaskCheckpointed):
op.execute(ctx)

persisted = ctx["task_state"].get("external_id")
assert persisted is not None

# Second attempt — no disruption.
op._disrupt_at_step = None
result = op.execute(ctx)

# _submit was NOT called again. The same external_id flowed through.
assert len(op.submit_calls) == 1
assert result == persisted
# task_state cleared on success.
assert ctx["task_state"].get("external_id") is None


def test_repeated_disruption_cycles_preserve_external_id(ctx):
"""Three cycles of disrupt → resume → disrupt → resume → success.

Validates that the framework primitives compose under repeated
interruption — the operator submits exactly once, all subsequent
attempts read the same external_id from task_state.
"""
op = _SimulatedResumableOp(task_id="t", total_steps=10)

# Cycle 1: disrupt at step 2.
op._disrupt_at_step = 2
with pytest.raises(AirflowTaskCheckpointed):
op.execute(ctx)

# Cycle 2: disrupt at step 5.
op._disrupt_at_step = 5
with pytest.raises(AirflowTaskCheckpointed):
op.execute(ctx)

# Cycle 3: success.
op._disrupt_at_step = None
result = op.execute(ctx)

# Exactly one submit across all attempts.
assert len(op.submit_calls) == 1
assert result == op.submit_calls[0]
assert ctx["task_state"].get("external_id") is None