From cee34764a366b170f5dd415fff4aa68532d1c896 Mon Sep 17 00:00:00 2001 From: Dhilip Subramanian <49802211+sdhilip200@users.noreply.github.com> Date: Sun, 21 Jun 2026 14:26:29 +1200 Subject: [PATCH] Add server-side XComModel value helper --- airflow-core/src/airflow/models/xcom.py | 47 +++++++++++++++++++ .../routes/public/test_extra_links.py | 32 +++++++++++++ airflow-core/tests/unit/models/test_xcom.py | 10 ++++ .../airflow/providers/common/compat/sdk.py | 3 +- .../tests/unit/common/compat/test_sdk.py | 7 +++ 5 files changed, 98 insertions(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index 5b27f244bf517..7681d1aa95098 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -54,6 +54,8 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause + from airflow.models.taskinstancekey import TaskInstanceKey + XCOM_RETURN_KEY = "return_value" @@ -332,6 +334,51 @@ def get_many( return query.limit(limit) return query + @classmethod + @provide_session + def get_one( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + include_prior_dates: bool = False, + session: Session = NEW_SESSION, + ) -> Any | None: + """Retrieve and deserialize a single XCom value from the metadata database.""" + result = session.execute( + cls.get_many( + key=key, + dag_ids=dag_id, + task_ids=task_id, + run_id=run_id, + map_indexes=map_index, + include_prior_dates=include_prior_dates, + limit=1, + ).with_only_columns(cls.value) + ).first() + if result is None: + return None + return cls.deserialize_value(result) + + @classmethod + def get_value( + cls, + *, + ti_key: TaskInstanceKey, + key: str, + ) -> Any | None: + """Retrieve and deserialize an XCom value for a task instance key.""" + return cls.get_one( + key=key, + dag_id=ti_key.dag_id, + task_id=ti_key.task_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + ) + @staticmethod def serialize_value( value: Any, diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py index 6816f330cfbe6..0f0ef8d7d5f93 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py @@ -72,6 +72,21 @@ class TryNumberPlugin(AirflowPlugin): operator_extra_links = [TryNumberLink()] +class XComModelLink(BaseOperatorLink): + name = "XCom Model" + operators = [CustomOperator] + + def get_link(self, operator, *, ti_key): + from airflow.providers.common.compat.sdk import XComModel + + return XComModel.get_value(key="plugin_link", ti_key=ti_key) or "" + + +class XComModelPlugin(AirflowPlugin): + name = "xcom_model_plugin" + operator_extra_links = [XComModelLink()] + + @pytest.mark.mock_plugin_manager(plugins=[]) class TestGetExtraLinks: dag_id = "TEST_DAG_ID" @@ -265,6 +280,23 @@ def test_should_respond_200_support_plugins(self, test_client): ).model_dump() ) + @pytest.mark.mock_plugin_manager(plugins=[XComModelPlugin]) + def test_plugin_link_can_read_xcom_with_xcommodel(self, test_client): + XCom.set( + key="plugin_link", + value="https://example.com/plugin-link", + task_id=self.task_single_link, + dag_id=self.dag_id, + run_id=self.dag_run_id, + ) + + response = test_client.get( + f"/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_single_link}/links", + ) + + assert response.status_code == 200 + assert response.json()["extra_links"]["XCom Model"] == "https://example.com/plugin-link" + def test_should_respond_200_mapped_task_instance(self, test_client, session): for map_index, value in enumerate(["TEST_LINK_VALUE_1", "TEST_LINK_VALUE_2"]): XCom.set( diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index 2635aac1803dd..97ce33fcddbae 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -240,6 +240,16 @@ def test_xcom_get_one(self, session, task_instance): ).first() assert XComModel.deserialize_value(stored_value) == {"key": "value"} + def test_xcommodel_get_value(self, task_instance, push_simple_json_xcom): + push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": "value"}) + + value = XComModel.get_value( + ti_key=task_instance.key, + key="xcom_1", + ) + + assert value == {"key": "value"} + @pytest.fixture def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom): date1 = timezone.datetime(2021, 12, 3, 4, 56) diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py b/providers/common/compat/src/airflow/providers/common/compat/sdk.py index 93174df7b2a28..88b7edb350ad1 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py +++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: import airflow.sdk.io as io # noqa: F401 import airflow.sdk.timezone as timezone # noqa: F401 - from airflow.models.xcom import XCOM_RETURN_KEY as XCOM_RETURN_KEY + from airflow.models.xcom import XCOM_RETURN_KEY as XCOM_RETURN_KEY, XComModel as XComModel from airflow.sdk import ( DAG as DAG, Asset as Asset, @@ -135,6 +135,7 @@ "AssetAlias": ("airflow.sdk", "airflow.datasets", "DatasetAlias"), "AssetAll": ("airflow.sdk", "airflow.datasets", "DatasetAll"), "AssetAny": ("airflow.sdk", "airflow.datasets", "DatasetAny"), + "XComModel": ("airflow.models.xcom", "airflow.models.xcom", "XCom"), } # Airflow 3-only renames (not available in Airflow 2) diff --git a/providers/common/compat/tests/unit/common/compat/test_sdk.py b/providers/common/compat/tests/unit/common/compat/test_sdk.py index 08b5107e4930a..f6e5329f07d86 100644 --- a/providers/common/compat/tests/unit/common/compat/test_sdk.py +++ b/providers/common/compat/tests/unit/common/compat/test_sdk.py @@ -52,3 +52,10 @@ def test_invalid_import_raises_attribute_error(): with pytest.raises(AttributeError, match="has no attribute 'NonExistentClass'"): _ = sdk.NonExistentClass + + +def test_xcom_model_imports_server_side_model(): + from airflow.models.xcom import XComModel + from airflow.providers.common.compat import sdk + + assert sdk.XComModel is XComModel