diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index a69a8aa205825..fed5aed537643 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``9ff64e1c35d3`` (head) | ``dd5f3a8e2b91`` | ``3.3.0`` | Add indexes on dag_run.created_dag_version_id and | +| ``623bce373cdf`` (head) | ``9ff64e1c35d3`` | ``3.3.0`` | Add is_refresh to AssetPartitionDagRun. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``9ff64e1c35d3`` | ``dd5f3a8e2b91`` | ``3.3.0`` | Add indexes on dag_run.created_dag_version_id and | | | | | task_instance.dag_version_id. 
| +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``dd5f3a8e2b91`` | ``c20871fbf23a`` | ``3.3.0`` | Add rollup_fingerprint to AssetPartitionDagRun and index | diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py index c8d2edef4f104..86fc81c15ff23 100644 --- a/airflow-core/src/airflow/assets/manager.py +++ b/airflow-core/src/airflow/assets/manager.py @@ -43,6 +43,8 @@ TaskOutletAssetReference, ) from airflow.models.log import Log +from airflow.partition_mappers.base import is_rollup +from airflow.partition_mappers.rerun_policy import RerunPolicy from airflow.timetables.base import compute_rollup_fingerprint from airflow.utils.helpers import is_container, prune_dict from airflow.utils.log.logging_mixin import LoggingMixin @@ -643,14 +645,22 @@ def _queue_partitioned_dags( ) continue + # rerun_policy only applies to rollups (it decides what happens when an + # upstream key re-arrives after the window already fired). Non-rollup + # mappers have no such field; ``None`` keeps their legacy behavior. + rerun_policy = mapper.rerun_policy if is_rollup(mapper) else None for target_key in target_keys: apdr = cls._get_or_create_apdr( target_key=target_key, target_dag=target_dag, rollup_fingerprint=fingerprint, asset_id=asset_id, + rerun_policy=rerun_policy, session=session, ) + # IGNORE dropped this late event for an already-fired window. + if apdr is None: + continue log_record = PartitionedAssetKeyLog( asset_id=asset_id, asset_event_id=event.id, @@ -669,8 +679,9 @@ def _get_or_create_apdr( target_dag: DagModel, rollup_fingerprint: dict, asset_id: int, + rerun_policy: RerunPolicy | None = None, session: Session, - ) -> AssetPartitionDagRun: + ) -> AssetPartitionDagRun | None: """ Get or create an APDR. 
@@ -683,6 +694,19 @@ def _get_or_create_apdr( ``rollup_fingerprint`` is the serialized mapper / window definition for all partitioned assets in the timetable at creation time; the scheduler discards APDRs whose stamp no longer matches the current timetable's fingerprint (mapper / window may have changed). + + ``rerun_policy`` governs what happens when the latest APDR for this (key, dag) has + already fired and a new upstream event arrives (an upstream partition was cleared and + re-run). It is set only for rollup mappers; ``None`` (non-rollup) keeps the legacy + "always create a fresh APDR" behavior, where the new APDR fires on the next tick. + + - :attr:`RerunPolicy.IGNORE` returns ``None`` so the caller drops the late event. + - :attr:`RerunPolicy.REFRESH` creates a new APDR flagged ``is_refresh`` so the + scheduler fires it immediately rather than waiting for the whole window again. + - :attr:`RerunPolicy.HOLD` creates a plain new APDR that waits for the whole window again. + + A pending (not-yet-fired) latest APDR is always reused regardless of policy, so events + accumulating toward a window's first firing are unaffected. """ with _lock_asset_model(session=session, asset_id=asset_id): latest_apdr: AssetPartitionDagRun | None = session.scalar( @@ -703,11 +727,25 @@ def _get_or_create_apdr( ) return latest_apdr + is_refresh = False + if latest_apdr is not None: + # The latest APDR already fired, so this event re-arrives for an + # already-materialized window. Apply the rollup's rerun policy. 
+ if rerun_policy is RerunPolicy.IGNORE: + cls.logger().debug( + "Dropping re-arrived event for fired window key %s dag_id %s (IGNORE)", + target_key, + target_dag.dag_id, + ) + return None + is_refresh = rerun_policy is RerunPolicy.REFRESH + apdr = AssetPartitionDagRun( target_dag_id=target_dag.dag_id, created_dag_run_id=None, partition_key=target_key, rollup_fingerprint=rollup_fingerprint, + is_refresh=is_refresh, ) session.add(apdr) session.flush() diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 3c5d71e9a3d02..ee76d398c3858 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -2261,23 +2261,35 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st source_key_by_asset = source_key_by_asset_per_apdr[apdr.id] timetable = dag.timetable - statuses: dict[SerializedAssetUniqueKey, bool] = {} - for asset_id, (name, uri) in asset_info_per_apdr[apdr.id].items(): - key = SerializedAssetUniqueKey(name=name, uri=uri) - if timetable.partitioned: - statuses[key] = self._resolve_asset_partition_status( - session=session, - asset_id=asset_id, - name=name, - uri=uri, - apdr=apdr, - timetable=timetable, - actual_by_asset=source_key_by_asset, - ) - else: - statuses[key] = True - if not evaluator.run(timetable.asset_condition, statuses=statuses): - continue + contributing_assets = asset_info_per_apdr[apdr.id] + if apdr.is_refresh: + # A refresh APDR supersedes an already-fired window + # (RerunPolicy.REFRESH): the rest of the window is still + # materialized, so fire immediately with the re-arrived events + # rather than re-satisfying the wait policy. Still require at + # least one active contributing asset, mirroring the + # freeze-on-inactive behavior below: an APDR whose only assets + # are inactive stays pending until they reactivate. 
+ if not contributing_assets: + continue + else: + statuses: dict[SerializedAssetUniqueKey, bool] = {} + for asset_id, (name, uri) in contributing_assets.items(): + key = SerializedAssetUniqueKey(name=name, uri=uri) + if timetable.partitioned: + statuses[key] = self._resolve_asset_partition_status( + session=session, + asset_id=asset_id, + name=name, + uri=uri, + apdr=apdr, + timetable=timetable, + actual_by_asset=source_key_by_asset, + ) + else: + statuses[key] = True + if not evaluator.run(timetable.asset_condition, statuses=statuses): + continue partition_dag_ids.add(apdr.target_dag_id) run_after = timezone.utcnow() diff --git a/airflow-core/src/airflow/migrations/versions/0123_3_3_0_add_is_refresh_to_apdr.py b/airflow-core/src/airflow/migrations/versions/0123_3_3_0_add_is_refresh_to_apdr.py new file mode 100644 index 0000000000000..7550ba4acbc9c --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0123_3_3_0_add_is_refresh_to_apdr.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add is_refresh to AssetPartitionDagRun. 
+ +The ``is_refresh`` flag marks a provisional partition Dag run that re-fires an +already-materialized rollup window after an upstream partition was cleared and +re-run under ``RerunPolicy.REFRESH``. The scheduler fires such a run immediately +with the current events instead of waiting for the whole window to re-arrive. + +Revision ID: 623bce373cdf +Revises: 9ff64e1c35d3 +Create Date: 2026-06-17 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.utils import disable_sqlite_fkeys + +revision = "623bce373cdf" +down_revision = "9ff64e1c35d3" +branch_labels = None +depends_on = None +airflow_version = "3.3.0" + + +def upgrade(): + """Add ``is_refresh`` to ``asset_partition_dag_run``.""" + with disable_sqlite_fkeys(op): + with op.batch_alter_table("asset_partition_dag_run", schema=None) as batch_op: + batch_op.add_column(sa.Column("is_refresh", sa.Boolean(), nullable=False, server_default="0")) + + +def downgrade(): + """Drop the APDR ``is_refresh`` column.""" + with disable_sqlite_fkeys(op): + with op.batch_alter_table("asset_partition_dag_run", schema=None) as batch_op: + batch_op.drop_column("is_refresh") diff --git a/airflow-core/src/airflow/models/asset.py b/airflow-core/src/airflow/models/asset.py index 34e4ffaec3e38..620a7e44f8610 100644 --- a/airflow-core/src/airflow/models/asset.py +++ b/airflow-core/src/airflow/models/asset.py @@ -24,6 +24,7 @@ import sqlalchemy as sa from sqlalchemy import ( JSON, + Boolean, Column, ForeignKey, ForeignKeyConstraint, @@ -931,6 +932,11 @@ class AssetPartitionDagRun(Base): # legacy rows that pre-date the column; they are treated as stale on the # next scheduler tick. rollup_fingerprint: Mapped[dict | None] = mapped_column(JSON, nullable=True) + # Set when this APDR is a refresh of an already-materialized window (an + # upstream partition was cleared and re-run under RerunPolicy.REFRESH). 
The + # window already fired once, so the scheduler fires a refresh immediately + # with the current events instead of waiting for the whole window again. + is_refresh: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="0") created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False diff --git a/airflow-core/src/airflow/partition_mappers/base.py b/airflow-core/src/airflow/partition_mappers/base.py index 0be82d8f6c5bf..405c23a451c38 100644 --- a/airflow-core/src/airflow/partition_mappers/base.py +++ b/airflow-core/src/airflow/partition_mappers/base.py @@ -20,6 +20,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard +from airflow.partition_mappers.rerun_policy import RerunPolicy from airflow.partition_mappers.wait_policy import WaitForAll, WaitPolicy if TYPE_CHECKING: @@ -137,6 +138,11 @@ class RollupMapper(PartitionMapper): ``wait_policy`` that decides when the downstream Dag run fires given the expected window and the upstream keys that have actually arrived. The default policy waits for every expected upstream key. + + ``rerun_policy`` decides what happens when an upstream partition that the + fired window already consumed is cleared and re-run; see :class:`RerunPolicy`. + The default :attr:`RerunPolicy.HOLD` waits for the whole window to + re-materialize (the historical behavior). 
""" is_rollup: ClassVar[bool] = True @@ -147,6 +153,7 @@ def __init__( upstream_mapper: PartitionMapper, window: Window, wait_policy: WaitPolicy | None = None, + rerun_policy: RerunPolicy | str = RerunPolicy.HOLD, max_downstream_keys: int | None = None, ) -> None: decode_overridden = type(upstream_mapper).decode_downstream is not PartitionMapper.decode_downstream @@ -166,6 +173,7 @@ def __init__( self.upstream_mapper = upstream_mapper self.window = window self.wait_policy = wait_policy + self.rerun_policy = RerunPolicy(rerun_policy) def to_downstream(self, key: str) -> str | Iterable[str]: return self.upstream_mapper.to_downstream(key) @@ -198,6 +206,7 @@ def serialize(self) -> dict[str, Any]: "upstream_mapper": encode_partition_mapper(self.upstream_mapper), "window": encode_window(self.window), "wait_policy": encode_wait_policy(self.wait_policy), + "rerun_policy": self.rerun_policy.value, } if self.max_downstream_keys is not None: data["max_downstream_keys"] = self.max_downstream_keys @@ -215,6 +224,9 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: upstream_mapper=decode_partition_mapper(data["upstream_mapper"]), window=decode_window(data["window"]), wait_policy=decode_wait_policy(data["wait_policy"]), + # Default for serialized Dags written before rerun_policy existed — + # HOLD reproduces their pre-feature behavior (wait for the full window). + rerun_policy=data.get("rerun_policy", RerunPolicy.HOLD), max_downstream_keys=data.get("max_downstream_keys"), ) diff --git a/airflow-core/src/airflow/partition_mappers/rerun_policy.py b/airflow-core/src/airflow/partition_mappers/rerun_policy.py new file mode 100644 index 0000000000000..b0691b71afde3 --- /dev/null +++ b/airflow-core/src/airflow/partition_mappers/rerun_policy.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from enum import Enum + + +class RerunPolicy(str, Enum): + """ + Core-side mirror of the SDK :class:`airflow.sdk.RerunPolicy`. + + Decides what the scheduler does when an upstream partition is cleared and + re-run after a rollup's downstream window has already fired. See the SDK + class for the authoring-facing documentation; the two are serialized by + string value so they round-trip across the Dag-parse / scheduler boundary. + + ``HOLD`` (default): queue a provisional run that waits for the entire window + to re-materialize before firing again. This is the historical behavior, so it + is the default to keep existing Dags unchanged. + + ``REFRESH``: re-fire the downstream Dag run immediately so it reprocesses with + the corrected upstream data. + + ``IGNORE``: drop the late upstream event; do not re-fire. 
+ """ + + REFRESH = "refresh" + HOLD = "hold" + IGNORE = "ignore" diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 59ac29f111ec7..50f4f3b3874eb 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -53,6 +53,7 @@ PartitionMapper, ProductMapper, QuarterWindow, + RerunPolicy, RollupMapper, SegmentWindow, StartOfDayMapper, @@ -530,6 +531,7 @@ def _(self, partition_mapper: RollupMapper) -> dict[str, Any]: "upstream_mapper": encode_partition_mapper(partition_mapper.upstream_mapper), "window": encode_window(partition_mapper.window), "wait_policy": encode_wait_policy(partition_mapper.wait_policy), + "rerun_policy": RerunPolicy(partition_mapper.rerun_policy).value, } if partition_mapper.max_downstream_keys is not None: data["max_downstream_keys"] = partition_mapper.max_downstream_keys diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index 400b63f3d7090..74b29ddb5b55d 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ b/airflow-core/src/airflow/timetables/base.py @@ -510,6 +510,30 @@ def run_info_from_dag_run(self, *, dag_run: DagRun) -> DagRunInfo: ) +def _encode_partition_mapper_for_fingerprint(mapper: PartitionMapper) -> dict[str, Any]: + """ + Encode *mapper* for the rollup fingerprint, excluding ``rerun_policy``. + + ``rerun_policy`` only decides what happens after a window has fired, not + which upstream keys the window requires, so it must not feed the fingerprint: + including it would let a policy-only edit discard in-flight APDRs, and would + make every pending partition run look stale on upgrade (stored fingerprints + pre-date the key). It lives only on :class:`RollupMapper`, so it is dropped + only there, at its known top-level location — gating on ``is_rollup`` keeps a + custom mapper that happens to use the same field name untouched. 
+ """ + # Local import to avoid a circular dependency: encoders.py imports Timetable + # from this module at the top level, so a top-level import here would cycle. + from airflow.partition_mappers.base import is_rollup + from airflow.serialization.encoders import encode_partition_mapper + from airflow.serialization.enums import Encoding + + encoded = encode_partition_mapper(mapper) + if is_rollup(mapper): + encoded.get(Encoding.VAR, {}).pop("rerun_policy", None) + return encoded + + def compute_rollup_fingerprint(timetable: Timetable) -> dict: """ Return the rollup-definition fingerprint for *timetable*. @@ -525,6 +549,11 @@ def compute_rollup_fingerprint(timetable: Timetable) -> dict: trigger cleanup of a stale partition Dag run, leaving unrelated Dag edits untouched. + The mapper's ``rerun_policy`` is deliberately excluded (see + :func:`_encode_partition_mapper_for_fingerprint`): it governs only post-fire + behavior, not which upstream keys a window requires, so it must not + invalidate in-flight APDRs. + Both the creation side (``assets/manager.py``) and the cleanup side (``jobs/scheduler_job_runner.py``) call this helper to guarantee the two fingerprints are computed by identical logic. @@ -532,26 +561,22 @@ def compute_rollup_fingerprint(timetable: Timetable) -> dict: if not timetable.partitioned: return {} - # Local import to avoid a circular dependency: encoders.py already imports - # Timetable from this module at the top level, so a top-level import of - # encode_partition_mapper here would create a cycle. 
from airflow.serialization.definitions.assets import SerializedAssetNameRef, SerializedAssetUriRef - from airflow.serialization.encoders import encode_partition_mapper entries: dict[str, dict[str, Any]] = {} for unique_key, _ in timetable.asset_condition.iter_assets(): mapper = timetable.get_partition_mapper(name=unique_key.name, uri=unique_key.uri) key = f"{unique_key.name}|{unique_key.uri}" - entries[key] = encode_partition_mapper(mapper) + entries[key] = _encode_partition_mapper_for_fingerprint(mapper) for s_asset_ref in timetable.asset_condition.iter_asset_refs(): if isinstance(s_asset_ref, SerializedAssetNameRef): mapper = timetable.get_partition_mapper(name=s_asset_ref.name) key = f"{s_asset_ref.name}|" - entries[key] = encode_partition_mapper(mapper) + entries[key] = _encode_partition_mapper_for_fingerprint(mapper) elif isinstance(s_asset_ref, SerializedAssetUriRef): mapper = timetable.get_partition_mapper(uri=s_asset_ref.uri) key = f"|{s_asset_ref.uri}" - entries[key] = encode_partition_mapper(mapper) + entries[key] = _encode_partition_mapper_for_fingerprint(mapper) return dict(sorted(entries.items())) diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 2155ca50d33b3..2863ea13aad62 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -116,7 +116,7 @@ class MappedClassProtocol(Protocol): "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", "3.2.0": "1d6611b6ab7c", - "3.3.0": "9ff64e1c35d3", + "3.3.0": "623bce373cdf", } # Prefix used to identify tables holding data moved during migration. 
diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py index 40d9401b9ffc0..243cbe2554378 100644 --- a/airflow-core/tests/unit/assets/test_manager.py +++ b/airflow-core/tests/unit/assets/test_manager.py @@ -44,6 +44,7 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.log import Log from airflow.models.team import Team +from airflow.partition_mappers.rerun_policy import RerunPolicy from airflow.partition_mappers.temporal import FanOutMapper, StartOfWeekMapper from airflow.partition_mappers.window import WeekWindow from airflow.providers.standard.operators.empty import EmptyOperator @@ -356,6 +357,96 @@ def test_queue_partitioned_dags_stamps_rollup_fingerprint(self, session, dag_mak "APDR rollup_fingerprint must match compute_rollup_fingerprint(timetable)" ) + @pytest.mark.need_serialized_dag + @pytest.mark.usefixtures("testing_dag_bundle") + @pytest.mark.parametrize( + ("rerun_policy", "expect_created", "expect_is_refresh"), + [ + pytest.param(RerunPolicy.REFRESH, True, True, id="refresh-creates-refresh-apdr"), + pytest.param(RerunPolicy.HOLD, True, False, id="hold-creates-plain-apdr"), + pytest.param(RerunPolicy.IGNORE, False, None, id="ignore-drops-event"), + pytest.param(None, True, False, id="non-rollup-keeps-legacy-behavior"), + ], + ) + def test_get_or_create_apdr_rerun_policy_after_window_fired( + self, session, dag_maker, rerun_policy, expect_created, expect_is_refresh + ): + """When the latest APDR already fired, the rerun policy decides what a re-arriving event does.""" + suffix = rerun_policy.value if rerun_policy is not None else "nonrollup" + dag_id = f"rerun-consumer-{suffix}" + asm = AssetModel(uri=f"test://rerun/{suffix}", name=f"rerun_asset_{suffix}", group="asset") + session.add(asm) + with dag_maker(dag_id=dag_id, session=session): + EmptyOperator(task_id="t") + dr = dag_maker.create_dagrun(session=session) + session.commit() + consumer = 
session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) + + # Seed a fired APDR for ("all", consumer): a window that already materialized. + session.add( + AssetPartitionDagRun( + target_dag_id=dag_id, + partition_key="all", + created_dag_run_id=dr.id, + rollup_fingerprint={}, + ) + ) + session.commit() + + result = AssetManager._get_or_create_apdr( + target_key="all", + target_dag=consumer, + rollup_fingerprint={}, + asset_id=asm.id, + rerun_policy=rerun_policy, + session=session, + ) + total = session.scalar( + select(func.count()) + .select_from(AssetPartitionDagRun) + .where(AssetPartitionDagRun.target_dag_id == dag_id) + ) + if not expect_created: + assert result is None + assert total == 1 # only the fired APDR; the late event was dropped + else: + assert result is not None + assert result.created_dag_run_id is None + assert result.is_refresh is expect_is_refresh + assert total == 2 + + @pytest.mark.need_serialized_dag + @pytest.mark.usefixtures("testing_dag_bundle") + def test_get_or_create_apdr_reuses_pending_regardless_of_policy(self, session, dag_maker): + """A not-yet-fired APDR is always reused, so events accumulating toward the first firing are never dropped.""" + asm = AssetModel(uri="test://reuse/", name="reuse_asset", group="asset") + session.add(asm) + with dag_maker(dag_id="reuse-consumer", session=session): + EmptyOperator(task_id="t") + session.commit() + consumer = session.scalar(select(DagModel).where(DagModel.dag_id == "reuse-consumer")) + + first = AssetManager._get_or_create_apdr( + target_key="all", + target_dag=consumer, + rollup_fingerprint={}, + asset_id=asm.id, + rerun_policy=RerunPolicy.REFRESH, + session=session, + ) + # Even IGNORE reuses the still-pending APDR rather than dropping the event. 
+ second = AssetManager._get_or_create_apdr( + target_key="all", + target_dag=consumer, + rollup_fingerprint={}, + asset_id=asm.id, + rerun_policy=RerunPolicy.IGNORE, + session=session, + ) + assert second is not None + assert second.id == first.id + assert second.is_refresh is False + @pytest.mark.usefixtures("testing_dag_bundle") def test_register_asset_change_queues_stale_dag(self, session, mock_task_instance): asset_manager = AssetManager() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index f905fccb7341e..4daa14a3ad845 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -119,6 +119,7 @@ HourWindow, IdentityMapper, MinimumCount, + RerunPolicy, RollupMapper, SegmentWindow, StartOfDayMapper, @@ -10617,6 +10618,152 @@ def test_partitioned_dag_run_rollup_holds_until_window_complete( assert partition_dags == {"rollup-consumer"} +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_rollup_refresh_refires_on_upstream_rerun( + dag_maker: DagMaker, + session: Session, +): + """ + RerunPolicy.REFRESH re-fires immediately when an upstream partition is re-run. + + After the window already fired, re-running a single upstream key creates a fresh + ``is_refresh`` APDR that fires on the next tick — it does NOT wait for the whole + window to re-materialize, because the rest of the window is still materialized. 
+ """ + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="refresh-rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu"]), + rerun_policy=RerunPolicy.REFRESH, + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # Both regions collapse onto the single "all_regions" partition; once both are + # present the window is complete and the rollup fires once. + for region in ("us", "eu"): + apdr = _produce_and_register_asset_event( + dag_id=f"refresh-producer-{region}", + asset=asset_1, + partition_key=region, + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + first_run_id = apdr.created_dag_run_id + assert first_run_id is not None + assert partition_dags == {"refresh-rollup-consumer"} + + # Clear & re-run only "us". REFRESH creates a new, refresh APDR ... + refresh_apdr = _produce_and_register_asset_event( + dag_id="refresh-producer-us-rerun", + asset=asset_1, + partition_key="us", + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + assert refresh_apdr.id != apdr.id + assert refresh_apdr.is_refresh is True + + # ... that fires on the next tick even though "eu" did not re-arrive (1/2 of window). 
+ partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(refresh_apdr) + assert refresh_apdr.created_dag_run_id is not None + assert refresh_apdr.created_dag_run_id != first_run_id + assert partition_dags == {"refresh-rollup-consumer"} + + +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_rollup_hold_waits_for_full_window_on_rerun( + dag_maker: DagMaker, + session: Session, +): + """ + RerunPolicy.HOLD makes a re-run wait for the entire window to re-materialize. + + After the window fired, re-running a single upstream key creates a fresh + non-refresh APDR that holds until every key in the window arrives again. + """ + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="hold-rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + upstream_mapper=FixedKeyMapper("all_regions"), + window=SegmentWindow(["us", "eu"]), + rerun_policy=RerunPolicy.HOLD, + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + for region in ("us", "eu"): + apdr = _produce_and_register_asset_event( + dag_id=f"hold-producer-{region}", + asset=asset_1, + partition_key=region, + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + + # Re-run only "us": HOLD creates a non-refresh APDR that must wait. 
+ hold_apdr = _produce_and_register_asset_event( + dag_id="hold-producer-us-rerun", + asset=asset_1, + partition_key="us", + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + assert hold_apdr.id != apdr.id + assert hold_apdr.is_refresh is False + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(hold_apdr) + assert hold_apdr.created_dag_run_id is None # held: only 1 / 2 of the window present + assert partition_dags == set() + + # When "eu" also re-arrives the window is complete again and the run fires. + _produce_and_register_asset_event( + dag_id="hold-producer-eu-rerun", + asset=asset_1, + partition_key="eu", + session=session, + dag_maker=dag_maker, + expected_partition_key="all_regions", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(hold_apdr) + assert hold_apdr.created_dag_run_id is not None + assert partition_dags == {"hold-rollup-consumer"} + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_partitioned_dag_run_segment_rollup_holds_until_all_segments_arrive( @@ -11525,6 +11672,39 @@ def test_partitioned_dag_run_skips_when_asset_is_inactive(dag_maker: DagMaker, s assert partition_dags == set() +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_refresh_apdr_freezes_when_asset_inactive(dag_maker: DagMaker, session: Session): + """ + A refresh APDR also freezes while its only contributing asset is inactive. + + The refresh fast-path skips the wait policy, so it must keep the same + freeze-on-inactive guard the normal path has — otherwise it would fire a run + that consumes no events. Without the ``if not contributing_assets`` guard this + APDR would fire immediately. 
+ """ + asset = Asset(name="asset-refresh-inactive") + [apdr] = _make_n_satisfied_apdrs( + consumer_dag_id="refresh-inactive-consumer", + asset=asset, + partition_keys=["k1"], + session=session, + dag_maker=dag_maker, + ) + apdr.is_refresh = True + session.commit() + _set_asset_active(name=asset.name, uri=asset.uri, session=session, active=False) + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_partitioned_dag_run_resumes_when_asset_reactivates(dag_maker: DagMaker, session: Session): diff --git a/airflow-core/tests/unit/partition_mappers/test_base.py b/airflow-core/tests/unit/partition_mappers/test_base.py index e7b0ba02f07f6..2afc4fb30f91f 100644 --- a/airflow-core/tests/unit/partition_mappers/test_base.py +++ b/airflow-core/tests/unit/partition_mappers/test_base.py @@ -22,6 +22,7 @@ from airflow.partition_mappers.base import PartitionMapper, RollupMapper from airflow.partition_mappers.identity import IdentityMapper +from airflow.partition_mappers.rerun_policy import RerunPolicy from airflow.partition_mappers.temporal import StartOfDayMapper from airflow.partition_mappers.window import DayWindow from airflow.serialization.decoders import decode_partition_mapper @@ -180,3 +181,26 @@ def test_max_downstream_keys_absent_from_default_encoded_payload(self): mapper = RollupMapper(upstream_mapper=StartOfDayMapper(), window=DayWindow()) encoded_var = encode_partition_mapper(mapper)[Encoding.VAR] assert "max_downstream_keys" not in encoded_var + + +class TestRollupMapperRerunPolicy: + @pytest.mark.parametrize("policy", list(RerunPolicy)) + def test_encode_decode_roundtrip(self, policy): + mapper = 
RollupMapper(upstream_mapper=StartOfDayMapper(), window=DayWindow(), rerun_policy=policy) + encoded = encode_partition_mapper(mapper) + assert encoded[Encoding.VAR]["rerun_policy"] == policy.value + assert decode_partition_mapper(encoded).rerun_policy is policy + + def test_default_is_hold_in_encoded_payload(self): + mapper = RollupMapper(upstream_mapper=StartOfDayMapper(), window=DayWindow()) + assert encode_partition_mapper(mapper)[Encoding.VAR]["rerun_policy"] == "hold" + + def test_missing_key_decodes_to_hold(self): + """A Dag serialized before rerun_policy existed has no key; it must default to HOLD (pre-feature behavior).""" + encoded = encode_partition_mapper( + RollupMapper( + upstream_mapper=StartOfDayMapper(), window=DayWindow(), rerun_policy=RerunPolicy.REFRESH + ) + ) + del encoded[Encoding.VAR]["rerun_policy"] + assert decode_partition_mapper(encoded).rerun_policy is RerunPolicy.HOLD diff --git a/airflow-core/tests/unit/timetables/test_base_timetable.py b/airflow-core/tests/unit/timetables/test_base_timetable.py index 5aaee3db06321..d000e61f0d80f 100644 --- a/airflow-core/tests/unit/timetables/test_base_timetable.py +++ b/airflow-core/tests/unit/timetables/test_base_timetable.py @@ -119,6 +119,58 @@ def test_compute_rollup_fingerprint_window_change_produces_different_fingerprint assert fp_hour != fp_day +def test_compute_rollup_fingerprint_excludes_rerun_policy(): + """ + Changing only ``rerun_policy`` must NOT change the fingerprint. + + ``rerun_policy`` governs post-fire behavior, not which upstream keys a window + requires. If it fed the fingerprint, a policy-only edit would discard in-flight + APDRs, and upgrading from a version that pre-dates the key would invalidate + every pending partition run (stored fingerprints lack the key). 
+ """ + from airflow.sdk import Asset, HourWindow, RerunPolicy, RollupMapper, StartOfHourMapper + from airflow.timetables.base import compute_rollup_fingerprint + from airflow.timetables.simple import PartitionedAssetTimetable + + asset_1 = Asset(name="asset-1") + + def fingerprint(policy): + tt = PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + upstream_mapper=StartOfHourMapper(), window=HourWindow(), rerun_policy=policy + ), + ) + return compute_rollup_fingerprint(tt) + + fingerprints = [fingerprint(policy) for policy in RerunPolicy] + assert all(fp == fingerprints[0] for fp in fingerprints) + assert "rerun_policy" not in str(fingerprints[0]) + + +def test_compute_rollup_fingerprint_non_rollup_mapper_unaffected(): + """ + A non-rollup partition mapper is fingerprinted unchanged. + + The ``rerun_policy`` exclusion is gated on ``is_rollup``, so a non-rollup + mapper (which never carries ``rerun_policy``) is passed through untouched — + this also guards against the exclusion ever stripping a same-named field + from some other mapper. + """ + from airflow.sdk import Asset, IdentityMapper + from airflow.serialization.encoders import encode_partition_mapper + from airflow.timetables.base import compute_rollup_fingerprint + from airflow.timetables.simple import PartitionedAssetTimetable + + asset_1 = Asset(name="asset-1") + mapper = IdentityMapper() + tt = PartitionedAssetTimetable(assets=asset_1, default_partition_mapper=mapper) + + fp = compute_rollup_fingerprint(tt) + # The entry is the full encoded mapper, identical to encoding it directly. 
+ assert fp[f"{asset_1.name}|{asset_1.uri}"] == encode_partition_mapper(mapper) + + def test_compute_rollup_fingerprint_multi_asset_all_keys_present(): """All assets appear as keys in the fingerprint.""" from airflow.sdk import Asset, HourWindow, RollupMapper, StartOfHourMapper diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index d289628097eeb..bd0223b848993 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -241,6 +241,8 @@ Partition Mapper .. autoapiclass:: airflow.sdk.RollupMapper +.. autoapiclass:: airflow.sdk.RerunPolicy + .. autoapiclass:: airflow.sdk.WaitForAll .. autoapiclass:: airflow.sdk.MinimumCount diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 9f7fd1b048430..f753b4162b172 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -73,6 +73,7 @@ "PokeReturnValue", "ProductMapper", "QuarterWindow", + "RerunPolicy", "ResumableJobMixin", "RetryAction", "RetryDecision", @@ -164,6 +165,7 @@ from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper + from airflow.sdk.definitions.partition_mappers.rerun_policy import RerunPolicy from airflow.sdk.definitions.partition_mappers.temporal import ( FanOutMapper, StartOfDayMapper, @@ -279,6 +281,7 @@ "RetryDecision": ".definitions.retry_policy", "RetryPolicy": ".definitions.retry_policy", "RetryRule": ".definitions.retry_policy", + "RerunPolicy": ".definitions.partition_mappers.rerun_policy", "RollupMapper": ".definitions.partition_mappers.base", "SecretCache": ".execution_time.cache", "SegmentWindow": ".definitions.partition_mappers.window", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 78e3746537969..3bdfd6462a1cf 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ 
b/task-sdk/src/airflow/sdk/__init__.pyi @@ -70,6 +70,7 @@ from airflow.sdk.definitions.partition_mappers.chain import ChainMapper from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper +from airflow.sdk.definitions.partition_mappers.rerun_policy import RerunPolicy from airflow.sdk.definitions.partition_mappers.temporal import ( FanOutMapper, StartOfDayMapper, @@ -177,6 +178,7 @@ __all__ = [ "PartitionMapper", "ProductMapper", "QuarterWindow", + "RerunPolicy", "RetryAction", "RetryDecision", "RetryPolicy", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py index 079b85fc02105..706b074ddc757 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py @@ -20,6 +20,7 @@ import attrs +from airflow.sdk.definitions.partition_mappers.rerun_policy import RerunPolicy from airflow.sdk.definitions.partition_mappers.wait_policy import WaitForAll, WaitPolicy if TYPE_CHECKING: @@ -67,6 +68,12 @@ class RollupMapper(PartitionMapper): ``MinimumCount(n)`` fires once at least ``n`` keys have arrived when ``n`` is positive, or once at most ``-n`` keys are still missing when ``n`` is negative. + + The ``rerun_policy`` is a :class:`RerunPolicy` that decides what happens when + an upstream partition is cleared and re-run after the downstream window has + already fired. The default ``RerunPolicy.HOLD`` waits for the whole window to + re-materialize (the historical behavior); ``RerunPolicy.REFRESH`` re-fires + immediately with the corrected data. 
""" is_rollup: ClassVar[bool] = True @@ -74,6 +81,7 @@ class RollupMapper(PartitionMapper): upstream_mapper: PartitionMapper = attrs.field(kw_only=True) window: Window = attrs.field(kw_only=True) wait_policy: WaitPolicy = attrs.field(factory=WaitForAll, kw_only=True) + rerun_policy: RerunPolicy = attrs.field(default=RerunPolicy.HOLD, kw_only=True, converter=RerunPolicy) def __attrs_post_init__(self) -> None: # Mirrors the core-side ``RollupMapper.__init__`` check so user code diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/rerun_policy.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/rerun_policy.py new file mode 100644 index 0000000000000..7c1d053ba02a1 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/rerun_policy.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from enum import Enum + + +class RerunPolicy(str, Enum): + """ + How a rollup reacts when an upstream partition is re-emitted after its downstream window already fired. + + A rollup fires its downstream Dag run once the window is satisfied (e.g. a + monthly rollup fires once all of March's daily partitions arrive). 
If an + upstream partition that the fired window already consumed is later cleared + and re-run, a fresh asset event arrives for an already-materialized window. + This policy decides what the scheduler does with it. + + ``HOLD`` (default): queue a provisional run that waits for the *entire* + window to re-materialize before firing again. A single re-run does not + re-fire; only a full recompute of the window does. This is the historical + behavior of a rollup before this policy existed, so it is the default to keep + existing Dags unchanged. + + ``REFRESH``: re-fire the downstream Dag run so it reprocesses with the + corrected upstream data. The rest of the window is still materialized, so the + refresh run fires immediately rather than waiting for the whole window to + re-arrive. This mirrors how a non-partitioned asset-triggered Dag re-runs on + every new asset event. + + ``IGNORE``: drop the late upstream event. The downstream Dag run is not + re-fired and no provisional run is queued. 
+ """ + + REFRESH = "refresh" + HOLD = "hold" + IGNORE = "ignore" diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py index ca811115cb72a..7c44fd190bdad 100644 --- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py +++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py @@ -24,6 +24,7 @@ from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper +from airflow.sdk.definitions.partition_mappers.rerun_policy import RerunPolicy from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper from airflow.sdk.definitions.partition_mappers.window import ( DayWindow, @@ -67,6 +68,35 @@ class _AlphaWindow(Window): RollupMapper(upstream_mapper=_StringOnlyMapper(), window=_AlphaWindow()) +class TestSdkRollupMapperRerunPolicy: + """``rerun_policy`` defaults to HOLD, coerces strings, and rejects junk.""" + + def _mapper(self, **kwargs): + return RollupMapper(upstream_mapper=StartOfDayMapper(), window=DayWindow(), **kwargs) + + def test_defaults_to_hold(self): + assert self._mapper().rerun_policy is RerunPolicy.HOLD + + @pytest.mark.parametrize("policy", list(RerunPolicy)) + def test_accepts_enum_member(self, policy): + assert self._mapper(rerun_policy=policy).rerun_policy is policy + + @pytest.mark.parametrize( + ("value", "expected"), + [ + ("refresh", RerunPolicy.REFRESH), + ("hold", RerunPolicy.HOLD), + ("ignore", RerunPolicy.IGNORE), + ], + ) + def test_coerces_string(self, value, expected): + assert self._mapper(rerun_policy=value).rerun_policy is expected + + def test_rejects_unknown_value(self): + with pytest.raises(ValueError, match="reprocess"): + self._mapper(rerun_policy="reprocess") + + class TestSdkDirectionValidation: """SDK Window.__init__ must coerce 
valid strings and reject invalid ones at construction time."""