Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions airflow-core/src/airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import Select, TextClause

from airflow.models.taskinstancekey import TaskInstanceKey


XCOM_RETURN_KEY = "return_value"

Expand Down Expand Up @@ -332,6 +334,51 @@ def get_many(
return query.limit(limit)
return query

@classmethod
@provide_session
def get_one(
cls,
*,
key: str,
dag_id: str,
task_id: str,
run_id: str,
map_index: int | None = None,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
) -> Any | None:
"""Retrieve and deserialize a single XCom value from the metadata database."""
result = session.execute(
cls.get_many(
key=key,
dag_ids=dag_id,
task_ids=task_id,
run_id=run_id,
map_indexes=map_index,
include_prior_dates=include_prior_dates,
limit=1,
).with_only_columns(cls.value)
).first()
if result is None:
return None
return cls.deserialize_value(result)

@classmethod
def get_value(
cls,
*,
ti_key: TaskInstanceKey,
key: str,
) -> Any | None:
"""Retrieve and deserialize an XCom value for a task instance key."""
return cls.get_one(
key=key,
dag_id=ti_key.dag_id,
task_id=ti_key.task_id,
run_id=ti_key.run_id,
map_index=ti_key.map_index,
)

@staticmethod
def serialize_value(
value: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ class TryNumberPlugin(AirflowPlugin):
operator_extra_links = [TryNumberLink()]


class XComModelLink(BaseOperatorLink):
name = "XCom Model"
operators = [CustomOperator]

def get_link(self, operator, *, ti_key):
from airflow.providers.common.compat.sdk import XComModel

return XComModel.get_value(key="plugin_link", ti_key=ti_key) or ""


class XComModelPlugin(AirflowPlugin):
name = "xcom_model_plugin"
operator_extra_links = [XComModelLink()]


@pytest.mark.mock_plugin_manager(plugins=[])
class TestGetExtraLinks:
dag_id = "TEST_DAG_ID"
Expand Down Expand Up @@ -265,6 +280,23 @@ def test_should_respond_200_support_plugins(self, test_client):
).model_dump()
)

@pytest.mark.mock_plugin_manager(plugins=[XComModelPlugin])
def test_plugin_link_can_read_xcom_with_xcommodel(self, test_client):
XCom.set(
key="plugin_link",
value="https://example.com/plugin-link",
task_id=self.task_single_link,
dag_id=self.dag_id,
run_id=self.dag_run_id,
)

response = test_client.get(
f"/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_single_link}/links",
)

assert response.status_code == 200
assert response.json()["extra_links"]["XCom Model"] == "https://example.com/plugin-link"

def test_should_respond_200_mapped_task_instance(self, test_client, session):
for map_index, value in enumerate(["TEST_LINK_VALUE_1", "TEST_LINK_VALUE_2"]):
XCom.set(
Expand Down
10 changes: 10 additions & 0 deletions airflow-core/tests/unit/models/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ def test_xcom_get_one(self, session, task_instance):
).first()
assert XComModel.deserialize_value(stored_value) == {"key": "value"}

def test_xcommodel_get_value(self, task_instance, push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": "value"})

value = XComModel.get_value(
ti_key=task_instance.key,
key="xcom_1",
)

assert value == {"key": "value"}

@pytest.fixture
def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom):
date1 = timezone.datetime(2021, 12, 3, 4, 56)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
if TYPE_CHECKING:
import airflow.sdk.io as io # noqa: F401
import airflow.sdk.timezone as timezone # noqa: F401
from airflow.models.xcom import XCOM_RETURN_KEY as XCOM_RETURN_KEY
from airflow.models.xcom import XCOM_RETURN_KEY as XCOM_RETURN_KEY, XComModel as XComModel
from airflow.sdk import (
DAG as DAG,
Asset as Asset,
Expand Down Expand Up @@ -135,6 +135,7 @@
"AssetAlias": ("airflow.sdk", "airflow.datasets", "DatasetAlias"),
"AssetAll": ("airflow.sdk", "airflow.datasets", "DatasetAll"),
"AssetAny": ("airflow.sdk", "airflow.datasets", "DatasetAny"),
"XComModel": ("airflow.models.xcom", "airflow.models.xcom", "XCom"),
}

# Airflow 3-only renames (not available in Airflow 2)
Expand Down
7 changes: 7 additions & 0 deletions providers/common/compat/tests/unit/common/compat/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,10 @@ def test_invalid_import_raises_attribute_error():

with pytest.raises(AttributeError, match="has no attribute 'NonExistentClass'"):
_ = sdk.NonExistentClass


def test_xcom_model_imports_server_side_model():
from airflow.models.xcom import XComModel
from airflow.providers.common.compat import sdk

assert sdk.XComModel is XComModel