diff --git a/providers/databricks/tests/system/databricks/example_resumable_databricks.py b/providers/databricks/tests/system/databricks/example_resumable_databricks.py new file mode 100644 index 0000000000000..72529cf6d6378 --- /dev/null +++ b/providers/databricks/tests/system/databricks/example_resumable_databricks.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +AIP-96 demonstration — Databricks resumable operator. + +A subclass of ``DatabricksSubmitRunOperator`` that survives worker +disruption (pod eviction, OOM kill, node drain) by: + + 1. Persisting the Databricks ``run_id`` via AIP-103 ``task_state`` + immediately after submit. + 2. Converting SIGTERM into ``AirflowTaskCheckpointed`` so the worker + transitions to ``CHECKPOINTED`` instead of running ``on_kill``'s + default ``cancel_run`` path. + 3. On the next attempt, reading the prior ``run_id`` from + ``task_state`` and reconnecting (skipping submit). + +This is the v1 integration pattern from AIP-96 v2: roughly 8 lines of +wrapper around the existing operator's submit + poll. The pattern shows +what a real provider integration looks like — not a synthetic example. + +NOTE: this file lives under ``tests/system/databricks/`` to match how +other Databricks operator examples are organized in the repo. It is +illustrative and stacks on top of the AIP-96 PR set (#66402, #66410, +#66445); not for merge in this form. Once AIP-96 is accepted, the +upstream-eligible shape would be either a new operator class shipped +alongside ``DatabricksSubmitRunOperator`` in +``providers/databricks/.../operators/databricks.py``, or a +``resumable=True`` flag merged into the existing operator. +""" + +from __future__ import annotations + +import datetime +import signal +from typing import TYPE_CHECKING + +from airflow.providers.databricks.operators.databricks import ( + DatabricksSubmitRunOperator, + _handle_databricks_operator_execution, + _handle_deferrable_databricks_operator_execution, +) +from airflow.providers.databricks.utils.databricks import normalise_json_content +from airflow.sdk import DAG +from airflow.sdk.exceptions import AirflowTaskCheckpointed + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class ResumableDatabricksSubmitRunOperator(DatabricksSubmitRunOperator): + """ + Databricks submit-run operator that survives worker disruption. + + Differences from the parent operator: + + - On the first attempt, persists ``self.run_id`` to AIP-103 + ``task_state`` immediately after submit. + - On the next attempt (after CHECKPOINTED), reads the prior + ``run_id`` from ``task_state`` and skips ``submit_run``, + reconnecting to the existing Databricks job. + - Installs a SIGTERM handler during execute that raises + ``AirflowTaskCheckpointed`` instead of letting the default + ``on_kill`` cancel the Databricks run. + - Overrides ``on_kill`` to be a no-op when checkpoint-style + preservation is desired (the default still cancels otherwise). + """ + + RESUME_KEY = "databricks_run_id" + + def execute(self, context: Context): + # AIP-103: read prior run_id, if any. + prior_run_id = context["task_state"].get(self.RESUME_KEY) + + if prior_run_id is not None: + self.log.info("Resuming Databricks run_id=%s from prior attempt", prior_run_id) + self.run_id = prior_run_id + else: + # First attempt: submit and persist. + if ( + "pipeline_task" in self.json + and self.json["pipeline_task"].get("pipeline_id") is None + and self.json["pipeline_task"].get("pipeline_name") + ): + pipeline_name = self.json["pipeline_task"]["pipeline_name"] + self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name) + del self.json["pipeline_task"]["pipeline_name"] + json_normalised = normalise_json_content(self.json) + self.run_id = self._hook.submit_run(json_normalised) + context["task_state"].set(self.RESUME_KEY, self.run_id) + self.log.info("Submitted Databricks run_id=%s and persisted to task_state", self.run_id) + + # Install a SIGTERM handler that signals checkpoint instead of + # cancel-on-kill. A future framework-level helper could install + # this automatically — see AIP-96 v2 reviewer questions. + def _on_sigterm(signum, frame): + raise AirflowTaskCheckpointed(checkpoint_data={"run_id": self.run_id}) + + prior_handler = signal.signal(signal.SIGTERM, _on_sigterm) + try: + if self.deferrable: + _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context) + else: + _handle_databricks_operator_execution(self, self._hook, self.log, context) + finally: + signal.signal(signal.SIGTERM, prior_handler) + + # Clear the resume key on success so a future DAG run starts fresh. + context["task_state"].delete(self.RESUME_KEY) + + def on_kill(self): + # Default DatabricksSubmitRunOperator.on_kill cancels the run. + # In the resumable variant we PRESERVE the run so the next + # attempt can reconnect via task_state. The SIGTERM handler in + # execute() raises AirflowTaskCheckpointed before this fires; + # this method is a safety net for non-SIGTERM kill paths. + if self.run_id: + self.log.info( + "Task with run_id=%s killed; preserving the Databricks run " + "for next-attempt reconnection (run_id stored in task_state).", + self.run_id, + ) + + +# --------------------------------------------------------------------------- +# Example DAG using the resumable operator. The notebook task path and +# cluster spec are illustrative; replace with real values to run against +# a Databricks workspace. +# --------------------------------------------------------------------------- + +with DAG( + dag_id="example_resumable_databricks", + description="AIP-96 demo: Databricks job survives worker disruption mid-execution", + schedule=None, + start_date=datetime.datetime(2026, 5, 1), + catchup=False, + tags=["example", "aip-96", "resumable", "databricks"], +) as dag: + new_cluster_spec = { + "spark_version": "13.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + } + + notebook_task = { + "notebook_path": "/Users/example@example.com/aip96_long_running_demo", + } + + resumable_databricks_run = ResumableDatabricksSubmitRunOperator( + task_id="resumable_databricks_run", + databricks_conn_id="databricks_default", + new_cluster=new_cluster_spec, + notebook_task=notebook_task, + retries=2, + retry_delay=datetime.timedelta(seconds=30), + ) diff --git a/providers/databricks/tests/system/databricks/test_aip96_resumable_pattern.py b/providers/databricks/tests/system/databricks/test_aip96_resumable_pattern.py new file mode 100644 index 0000000000000..9a37d64a9ba5c --- /dev/null +++ b/providers/databricks/tests/system/databricks/test_aip96_resumable_pattern.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Layer 2 e2e — framework-level test of the AIP-96 resumable pattern. + +Provider-agnostic. Uses an in-process simulator (no Databricks, no +Kubernetes, no real external job) to exercise the AIP-96 + AIP-103 +primitives end-to-end: + + - AirflowTaskCheckpointed from airflow.sdk.exceptions + - task_state.set/get/delete (AIP-103) — fake backend used here + +Shows that the resume contract composes correctly without per-provider +plumbing. The same shape works for any operator that has a submit-then-poll +external-job structure (Databricks, EMR, Spark-on-K8s, Beam, Dataproc, etc.). + +Co-located with the Databricks demo because the discussion lives there; +in a real upstream PR this test would live next to the AIP-96 supervisor +wiring code (#66445) so the framework primitives are tested without +provider deps. +""" + +from __future__ import annotations + +import secrets +from typing import Any + +import pytest + +from airflow.sdk import BaseOperator +from airflow.sdk.exceptions import AirflowTaskCheckpointed + + +class _FakeTaskState: + """In-memory task_state — stand-in for AIP-103's context['task_state'].""" + + def __init__(self): + self._d: dict[str, Any] = {} + + def get(self, key): + return self._d.get(key) + + def set(self, key, value): + self._d[key] = value + + def delete(self, key): + self._d.pop(key, None) + + +class _SimulatedResumableOp(BaseOperator): + """ + Minimal resumable operator using an in-process external job simulator. + + Mirrors the structure of ResumableDatabricksSubmitRunOperator without + depending on the Databricks provider. The 'external job' is just + counter state on the operator instance. + + Test hooks: + - ``_disrupt_at_step``: if set, raises AirflowTaskCheckpointed when + the poll loop reaches that step (simulates worker SIGTERM mid-poll). + """ + + RESUME_KEY = "external_id" + + def __init__(self, *, total_steps: int = 3, **kwargs): + super().__init__(**kwargs) + self.total_steps = total_steps + self.submit_calls: list[str] = [] + self.poll_calls: list[tuple[str, int]] = [] + self._disrupt_at_step: int | None = None + + def execute(self, context): + external_id = context["task_state"].get(self.RESUME_KEY) + if external_id is None: + external_id = self._submit() + context["task_state"].set(self.RESUME_KEY, external_id) + + for step in range(self.total_steps): + self.poll_calls.append((external_id, step)) + if step == self._disrupt_at_step: + raise AirflowTaskCheckpointed(checkpoint_data={"external_id": external_id}) + + context["task_state"].delete(self.RESUME_KEY) + return external_id + + def _submit(self) -> str: + external_id = f"job-{secrets.token_hex(2)}" + self.submit_calls.append(external_id) + return external_id + + +@pytest.fixture +def ctx(): + return {"task_state": _FakeTaskState()} + + +def test_first_attempt_submits_persists_then_clears_on_success(ctx): + op = _SimulatedResumableOp(task_id="t") + + result = op.execute(ctx) + + assert len(op.submit_calls) == 1 + assert result == op.submit_calls[0] + # Polled all steps. + assert len(op.poll_calls) == op.total_steps + # Cleared on success. + assert ctx["task_state"].get("external_id") is None + + +def test_disruption_raises_checkpointed_and_persists_external_id(ctx): + op = _SimulatedResumableOp(task_id="t") + op._disrupt_at_step = 1 + + with pytest.raises(AirflowTaskCheckpointed) as exc_info: + op.execute(ctx) + + persisted = ctx["task_state"].get("external_id") + assert persisted is not None + assert persisted == op.submit_calls[0] + assert exc_info.value.checkpoint_data == {"external_id": persisted} + # Polled up to and including the disruption step. + assert len(op.poll_calls) == op._disrupt_at_step + 1 + + +def test_resume_after_checkpoint_skips_submit_and_completes(ctx): + op = _SimulatedResumableOp(task_id="t") + + # First attempt — disrupted. + op._disrupt_at_step = 1 + with pytest.raises(AirflowTaskCheckpointed): + op.execute(ctx) + + persisted = ctx["task_state"].get("external_id") + assert persisted is not None + + # Second attempt — no disruption. + op._disrupt_at_step = None + result = op.execute(ctx) + + # _submit was NOT called again. The same external_id flowed through. + assert len(op.submit_calls) == 1 + assert result == persisted + # task_state cleared on success. + assert ctx["task_state"].get("external_id") is None + + +def test_repeated_disruption_cycles_preserve_external_id(ctx): + """Three cycles of disrupt → resume → disrupt → resume → success. + + Validates that the framework primitives compose under repeated + interruption — the operator submits exactly once, all subsequent + attempts read the same external_id from task_state. + """ + op = _SimulatedResumableOp(task_id="t", total_steps=10) + + # Cycle 1: disrupt at step 2. + op._disrupt_at_step = 2 + with pytest.raises(AirflowTaskCheckpointed): + op.execute(ctx) + + # Cycle 2: disrupt at step 5. + op._disrupt_at_step = 5 + with pytest.raises(AirflowTaskCheckpointed): + op.execute(ctx) + + # Cycle 3: success. + op._disrupt_at_step = None + result = op.execute(ctx) + + # Exactly one submit across all attempts. + assert len(op.submit_calls) == 1 + assert result == op.submit_calls[0] + assert ctx["task_state"].get("external_id") is None diff --git a/providers/databricks/tests/system/databricks/test_resumable_databricks_demo.py b/providers/databricks/tests/system/databricks/test_resumable_databricks_demo.py new file mode 100644 index 0000000000000..2b667da1599af --- /dev/null +++ b/providers/databricks/tests/system/databricks/test_resumable_databricks_demo.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Layer 1 e2e — unit test for ResumableDatabricksSubmitRunOperator (AIP-96 demo). + +Mock-based: no real Databricks workspace required. Asserts the resume contract: + + - First execute(): submit_run called, run_id stored in task_state. + - SIGTERM during poll raises AirflowTaskCheckpointed. + - Second execute() (after CHECKPOINTED): submit_run NOT called; prior run_id reused. + - Success path: task_state cleared. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.sdk.exceptions import AirflowTaskCheckpointed + +from system.databricks.example_resumable_databricks import ( + ResumableDatabricksSubmitRunOperator, +) + + +class _FakeTaskState: + """In-memory task_state for tests.""" + + def __init__(self): + self._d: dict[str, object] = {} + + def get(self, key): + return self._d.get(key) + + def set(self, key, value): + self._d[key] = value + + def delete(self, key): + self._d.pop(key, None) + + +def _make_op(**kwargs): + return ResumableDatabricksSubmitRunOperator( + task_id="resumable_databricks_run", + databricks_conn_id="databricks_default", + new_cluster={ + "spark_version": "13.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + }, + notebook_task={"notebook_path": "/Users/x/y"}, + **kwargs, + ) + + +@pytest.fixture +def fake_context(): + return {"task_state": _FakeTaskState()} + + +def _patch_hook(op, run_id: int = 12345): + """Replace the operator's _hook property with a MagicMock returning run_id.""" + hook = MagicMock() + hook.submit_run.return_value = run_id + # Override the property descriptor on the instance via __dict__ + object.__setattr__(op, "_hook", hook) + return hook + + +def test_first_attempt_submits_and_persists_run_id(fake_context): + op = _make_op() + hook = _patch_hook(op, run_id=12345) + + target = "system.databricks.example_resumable_databricks._handle_databricks_operator_execution" + with patch(target): + op.execute(fake_context) + + assert hook.submit_run.call_count == 1 + # Cleared on success. + assert fake_context["task_state"].get("databricks_run_id") is None + + +def test_resume_after_checkpoint_skips_submit_and_reuses_run_id(fake_context): + op = _make_op() + hook = _patch_hook(op, run_id=99999) + + # Pre-seed task_state as if the prior attempt was disrupted. + fake_context["task_state"].set("databricks_run_id", 12345) + + target = "system.databricks.example_resumable_databricks._handle_databricks_operator_execution" + with patch(target): + op.execute(fake_context) + + # submit_run NOT called — we reused the prior run_id. + assert hook.submit_run.call_count == 0 + # Operator's run_id was restored from task_state. + assert op.run_id == 12345 + # task_state cleared on success. + assert fake_context["task_state"].get("databricks_run_id") is None + + +def test_disruption_during_poll_raises_checkpointed_and_run_id_persists(fake_context): + op = _make_op() + hook = _patch_hook(op, run_id=12345) + + def _simulate_disruption(operator, _hook, _log, _context): + # Simulate a SIGTERM-handler raise during the poll loop. + raise AirflowTaskCheckpointed(checkpoint_data={"run_id": operator.run_id}) + + target = "system.databricks.example_resumable_databricks._handle_databricks_operator_execution" + with patch(target, side_effect=_simulate_disruption): + with pytest.raises(AirflowTaskCheckpointed) as exc_info: + op.execute(fake_context) + + # submit_run called once (first attempt), run_id persisted before disruption. + assert hook.submit_run.call_count == 1 + assert fake_context["task_state"].get("databricks_run_id") == 12345 + assert exc_info.value.checkpoint_data == {"run_id": 12345} + + +def test_on_kill_does_not_cancel_run_for_resumable_variant(): + op = _make_op() + hook = _patch_hook(op) + op.run_id = 12345 + + op.on_kill() + + # Critical: the resumable variant must NOT cancel the Databricks run on kill. + # That's the whole point — preserve the external job for next-attempt reconnection. + assert hook.cancel_run.call_count == 0