Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf
from airflow.providers.common.compat.sdk import BaseOperator, BaseOperatorLink, XCom, conf
from airflow.providers.dbt.cloud.hooks.dbt import (
DbtCloudHook,
DbtCloudJobRunException,
Expand Down Expand Up @@ -250,16 +250,16 @@ def execute(self, context: Context):
# execution_timeout is a hard task-level limit (cancels the job),
# while timeout only limits how long we wait for the job to finish.
# If both are set, the earliest deadline wins.
end_time = time.time() + self.timeout
end_time = time.monotonic() + self.timeout
execution_deadline = None
if self.execution_timeout:
execution_deadline = time.time() + self.execution_timeout.total_seconds()
if self.execution_timeout is not None:
execution_deadline = time.monotonic() + self.execution_timeout.total_seconds()

job_run_info = JobRunInfo(account_id=self.account_id, run_id=self.run_id)
job_run_status = self.hook.get_job_run_status(**job_run_info)
if not DbtCloudJobRunStatus.is_terminal(job_run_status):
self.defer(
timeout=None,
timeout=self.execution_timeout,
trigger=DbtCloudRunJobTrigger(
conn_id=self.dbt_cloud_conn_id,
run_id=self.run_id,
Expand Down Expand Up @@ -293,8 +293,22 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> int:

# Enforce execution_timeout semantics in deferrable mode by cancelling the job.
if event["status"] == "timeout":
self.hook.cancel_job_run(account_id=self.account_id, run_id=self.run_id)
raise AirflowException(f"Job run {self.run_id} has timed out.")
if self.run_id is not None:
self.log.info("Cancelling DBT job run %s due to execution timeout", self.run_id)

# Attempt best-effort job run cancellation.
try:
self.hook.cancel_job_run(account_id=self.account_id, run_id=self.run_id)
except Exception:
self.log.warning(
"Failed to cancel DBT job run %s after timeout",
self.run_id,
exc_info=True,
)
else:
self.log.warning("No run_id found; skipping cancellation")

raise DbtCloudJobRunException(f"Job run {self.run_id} has timed out.")

self.log.info(event["message"])
return int(event["run_id"])
Expand All @@ -303,7 +317,15 @@ def on_kill(self) -> None:
if not self.run_id:
return

self.hook.cancel_job_run(account_id=self.account_id, run_id=self.run_id)
# Attempt best-effort job run cancellation.
try:
self.hook.cancel_job_run(account_id=self.account_id, run_id=self.run_id)
except Exception:
self.log.warning(
"Failed to cancel DBT job run %s during on_kill",
self.run_id,
exc_info=True,
)

# Attempt best-effort confirmation of cancellation.
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,43 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to Dbt, polls for the pipeline run status."""
hook = DbtCloudHook(self.conn_id, **self.hook_params)
try:
while await self.is_still_running(hook):
while True:
now = time.monotonic()

job_run_status = await hook.get_job_status(self.run_id, self.account_id)

if job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
yield TriggerEvent(
{
"status": "success",
"message": f"Job run {self.run_id} has completed successfully.",
"run_id": self.run_id,
}
)
return

elif job_run_status == DbtCloudJobRunStatus.CANCELLED.value:
yield TriggerEvent(
{
"status": "cancelled",
"message": f"Job run {self.run_id} has been cancelled.",
"run_id": self.run_id,
}
)
return

elif job_run_status == DbtCloudJobRunStatus.ERROR.value:
yield TriggerEvent(
{
"status": "error",
"message": f"Job run {self.run_id} has failed.",
"run_id": self.run_id,
}
)
return

if self.execution_deadline is not None:
if self.execution_deadline < time.time():
if self.execution_deadline <= now:
yield TriggerEvent(
{
"status": "timeout",
Expand All @@ -91,11 +125,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
)
return

if self.end_time < time.time():
# Perform a final status check before declaring timeout, in case the
# job completed between the last poll and the timeout expiry.
if not await self.is_still_running(hook):
break
if self.end_time <= now:
yield TriggerEvent(
{
"status": "error",
Expand All @@ -105,34 +135,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
}
)
return

await asyncio.sleep(self.poll_interval)
job_run_status = await hook.get_job_status(self.run_id, self.account_id)
if job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
yield TriggerEvent(
{
"status": "success",
"message": f"Job run {self.run_id} has completed successfully.",
"run_id": self.run_id,
}
)
elif job_run_status == DbtCloudJobRunStatus.CANCELLED.value:
yield TriggerEvent(
{
"status": "cancelled",
"message": f"Job run {self.run_id} has been cancelled.",
"run_id": self.run_id,
}
)
else:
yield TriggerEvent(
{
"status": "error",
"message": f"Job run {self.run_id} has failed.",
"run_id": self.run_id,
}
)

except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})
return

async def is_still_running(self, hook: DbtCloudHook) -> bool:
"""Check whether the submitted job is running."""
Expand Down
133 changes: 130 additions & 3 deletions providers/dbt/cloud/tests/unit/dbt/cloud/operators/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest

from airflow.models import DAG, Connection
from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred, timezone
from airflow.providers.common.compat.sdk import TaskDeferred, timezone
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
from airflow.providers.dbt.cloud.operators.dbt import (
DbtCloudGetJobRunArtifactOperator,
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_execute_failed_before_getting_deferred(
def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, status):
"""
Asserts that a task is deferred and a DbtCloudRunJobTrigger will be fired
when the DbtCloudRunJobOperator has deferrable param set to True
when the DbtCloudRunJobOperator has deferrable param set to True.
"""
mock_job_run_status.return_value = status
dbt_op = DbtCloudRunJobOperator(
Expand All @@ -214,6 +214,40 @@ def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_jo
dbt_op.execute(MagicMock())
assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"

def test_execute_complete_timeout_without_run_id(self):
    """
    Verify that when a deferrable dbt job emits a timeout event with no run_id,
    the operator skips cancellation (there is nothing to cancel) but still fails.
    """
    operator = DbtCloudRunJobOperator(
        task_id=TASK_ID,
        dbt_cloud_conn_id=ACCOUNT_ID_CONN,
        job_id=JOB_ID,
        dag=self.dag,
        deferrable=True,
    )

    # Simulate a timeout that fired before a run_id was ever assigned.
    operator.run_id = None

    # Mock the hook so we can assert that no cancellation call is made.
    operator.hook = MagicMock()

    timeout_event = {
        "status": "timeout",
        "run_id": None,
        "message": "Job run timed out.",
    }

    # The task must still fail on timeout even though cancellation is skipped.
    with pytest.raises(DbtCloudJobRunException):
        operator.execute_complete(
            context=self.mock_context,
            event=timeout_event,
        )

    # With no run_id, cancel_job_run must never be invoked.
    operator.hook.cancel_job_run.assert_not_called()

def test_execute_complete_timeout_cancels_job(self):
"""
Verify that when a deferrable dbt job emits a timeout event,
Expand All @@ -239,7 +273,45 @@ def test_execute_complete_timeout_cancels_job(self):
"message": "Job run timed out.",
}

with pytest.raises(AirflowException, match="has timed out"):
with pytest.raises(DbtCloudJobRunException, match="has timed out"):
operator.execute_complete(
context=self.mock_context,
event=timeout_event,
)

operator.hook.cancel_job_run.assert_called_once_with(
account_id=operator.account_id,
run_id=RUN_ID,
)

def test_execute_complete_timeout_cancel_job_does_not_mask_original_error(self):
"""
Verify that when a deferrable dbt job is cancelled after a timeout event is received,
the original error is not masked.
"""
operator = DbtCloudRunJobOperator(
task_id=TASK_ID,
dbt_cloud_conn_id=ACCOUNT_ID_CONN,
job_id=JOB_ID,
dag=self.dag,
deferrable=True,
)

# Pretend the job was already triggered.
operator.run_id = RUN_ID

# Mock the hook so we can assert cancellation.
operator.hook = MagicMock()

operator.hook.cancel_job_run.side_effect = Exception("Cancellation failed")

timeout_event = {
"status": "timeout",
"run_id": RUN_ID,
"message": "Job run timed out.",
}

with pytest.raises(DbtCloudJobRunException, match="has timed out"):
operator.execute_complete(
context=self.mock_context,
event=timeout_event,
Expand Down Expand Up @@ -689,6 +761,61 @@ def test_custom_trigger_reason(self, mock_run_job, conn_id, account_id):
additional_run_config=self.config["additional_run_config"],
)

def test_on_kill_cancels_job_and_confirms_success(self):
    """Ensure on_kill issues a cancellation request and then confirms it succeeded."""
    operator = DbtCloudRunJobOperator(
        task_id=TASK_ID,
        dbt_cloud_conn_id=ACCOUNT_ID_CONN,
        job_id=JOB_ID,
        dag=self.dag,
    )
    operator.run_id = RUN_ID

    mocked_hook = MagicMock()
    # Confirmation of the cancelled status succeeds on the first check.
    mocked_hook.wait_for_job_run_status.return_value = True
    operator.hook = mocked_hook

    operator.on_kill()

    # Exactly one cancellation request, addressed to the tracked run.
    mocked_hook.cancel_job_run.assert_called_once_with(
        account_id=operator.account_id,
        run_id=RUN_ID,
    )
    # And exactly one confirmation poll with the operator's own settings.
    mocked_hook.wait_for_job_run_status.assert_called_once_with(
        run_id=RUN_ID,
        account_id=operator.account_id,
        expected_statuses=DbtCloudJobRunStatus.CANCELLED.value,
        check_interval=operator.check_interval,
        timeout=operator.timeout,
    )

def test_on_kill_best_effort_cancellation_does_not_raise(self):
    """on_kill must swallow failures from both cancellation and its confirmation."""
    operator = DbtCloudRunJobOperator(
        task_id=TASK_ID,
        dbt_cloud_conn_id=ACCOUNT_ID_CONN,
        job_id=JOB_ID,
        dag=self.dag,
    )
    operator.run_id = RUN_ID

    failing_hook = MagicMock()
    # Both the cancellation request and the follow-up status check blow up.
    failing_hook.cancel_job_run.side_effect = Exception("Cancellation failed")
    failing_hook.wait_for_job_run_status.side_effect = DbtCloudJobRunException("Still running")
    operator.hook = failing_hook

    # Best-effort semantics: no exception may escape on_kill.
    operator.on_kill()

    failing_hook.cancel_job_run.assert_called_once_with(
        account_id=operator.account_id,
        run_id=RUN_ID,
    )
    failing_hook.wait_for_job_run_status.assert_called_once()

@pytest.mark.parametrize(
("conn_id", "account_id"),
[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
Expand Down
Loading
Loading