From abb3dcab517a7ee97ed7e24af644c1fc95d88945 Mon Sep 17 00:00:00 2001 From: PoAn Yang Date: Sat, 20 Jun 2026 18:29:10 +0900 Subject: [PATCH] Add SQS shared-stream producer and trigger Many event-driven triggers and asset watchers can watch the same SQS queue. Without sharing, each opens its own consumer, multiplying ReceiveMessage calls and in-flight load against a single queue. The producer-side ack channel (#67523) lets sibling triggers share one consumer in the triggerer, with message deletion gated on trigger-event persistence -- so a triggerer crash leaves an unconfirmed message to be redelivered rather than dropped (at-least-once). This adds the SQS backend for that ack channel. Signed-off-by: PoAn Yang --- .../provider_dependencies.json.sha256sum | 2 +- providers/amazon/provider.yaml | 1 + .../amazon/aws/triggers/shared_stream.py | 397 +++++++++++++++++ .../providers/amazon/get_provider_info.py | 5 +- .../amazon/aws/triggers/test_shared_stream.py | 398 ++++++++++++++++++ 5 files changed, 801 insertions(+), 2 deletions(-) create mode 100644 providers/amazon/src/airflow/providers/amazon/aws/triggers/shared_stream.py create mode 100644 providers/amazon/tests/unit/amazon/aws/triggers/test_shared_stream.py diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum index 943fd0fc93e4c..dc286285740bd 100644 --- a/generated/provider_dependencies.json.sha256sum +++ b/generated/provider_dependencies.json.sha256sum @@ -1 +1 @@ -93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b +d951411b2236516f9714c53ec8da1083fe33931a17de45be60791c26c1c37c68 diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index 2ebbef397176b..4e2c96146e0f6 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -865,6 +865,7 @@ triggers: - integration-name: Amazon Simple Queue Service (SQS) python-modules: - airflow.providers.amazon.aws.triggers.sqs + - 
airflow.providers.amazon.aws.triggers.shared_stream - integration-name: AWS Glue DataBrew python-modules: - airflow.providers.amazon.aws.triggers.glue_databrew diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/shared_stream.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/shared_stream.py new file mode 100644 index 0000000000000..3a38dacb164f0 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/shared_stream.py @@ -0,0 +1,397 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared-stream SQS trigger and producer for event-driven scheduling. + +Triggers that declare the same ``sqs_queue``, ``aws_conn_id`` and ``region_name`` share a single +SQS consumer in the triggerer (one receive loop broadcast to every subscriber) +instead of polling the queue once each. The :class:`SqsSharedStreamProducer` +owns that consumer and deletes a message only after the derived +:class:`~airflow.triggers.base.TriggerEvent` instances have been persisted, via +the shared-stream ack channel. + +SQS is explicit-delete: ``ReceiveMessage`` only makes a message invisible for +the visibility timeout, never removing it, so an unconfirmed message is +redelivered rather than lost. 
Deleting only after the event is persisted -- +never on receipt -- is what keeps a triggerer crash from dropping a message the +queue still considers in-flight. + +Requires an Airflow version whose ``airflow.triggers.shared_stream`` module +provides the producer-side ack channel. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import AsyncIterator, Hashable, Iterator, Sequence +from contextlib import AsyncExitStack, suppress +from typing import Any, NamedTuple + +from airflow.providers.amazon.aws.hooks.sqs import SqsHook +from airflow.triggers.base import BaseEventTrigger, TriggerEvent +from airflow.triggers.shared_stream import AdvanceItem, SharedStreamProducer + +log = logging.getLogger(__name__) + +# SQS caps batch entries (delete / change-visibility) at 10 per request. +_SQS_BATCH_LIMIT = 10 + + +def _in_batches(items: Sequence[Any], size: int) -> Iterator[Sequence[Any]]: + for start in range(0, len(items), size): + yield items[start : start + size] + + +class SqsBrokerPayload(NamedTuple): + """ + The ``broker_payload`` carried alongside each raw event from ``open_stream``. + + ``receipt_handle`` is what :meth:`SqsSharedStreamProducer.advance` deletes + with and what the background renewer extends visibility with; it is bound at + receive time and is not changed by visibility extension. ``message_id`` is + the message's stable identity (unchanged across redelivery) kept for logging + and tracing. ``group_id`` is the FIFO ``MessageGroupId`` used as the advance + lane key; it is ``None`` for standard queues. + """ + + receipt_handle: str + message_id: str + group_id: str | None + + +class SqsSharedStreamProducer(SharedStreamProducer): + """ + Broker-side half of a shared SQS stream running in ack mode. 
+ + Drives one SQS consumer for a shared-stream group and deletes a message + only after every subscriber that derived a ``TriggerEvent`` from it has had + that event persisted -- the ack channel gates the delete, so a triggerer + crash cannot drop a message the queue already considers in-flight. + + ``ReceiveMessage`` only makes a message invisible for the visibility timeout + and never deletes it, so SQS is at-least-once by default. This producer keeps + it that way by deleting exclusively from :meth:`advance`, never on receipt. + + While a message is outstanding (received, not yet resolved) a background + task extends its visibility timeout before it lapses, so a slow subscriber + does not cause SQS to redeliver a message that is still being processed. + The visibility timeout therefore only bounds redelivery latency after a + triggerer crash (when renewal stops), not how long a subscriber may take. + + Events a subscriber **rejects** (via ``reject_shared_stream_event``) are + terminal and dropped (deleted without redelivery). Involuntary failures + (ack timeout / overflow) reset the message's visibility to zero so SQS + redelivers it immediately instead of waiting out the visibility timeout. A + broadcast no subscriber was online for is left to lapse and be redelivered + once a subscriber returns. Configure a redrive policy on the queue to + dead-letter messages that fail repeatedly. + + :param sqs_queue: URL of the SQS queue the shared consumer receives from. + :param aws_conn_id: AWS connection id, defaults to ``aws_default``. + :param max_messages: Maximum messages to fetch per ``ReceiveMessage`` (1-10). + :param wait_time_seconds: Long-poll wait per ``ReceiveMessage`` call. + :param visibility_timeout: Visibility timeout applied at receive time and on + every renewal. Renewal keeps the message invisible while it is + outstanding, so this bounds redelivery latency after a crash, not + subscriber processing time. + :param region_name: AWS region of the queue. 
+ :param verify: ``botocore`` TLS verification flag passed to the hook. + :param botocore_config: Optional ``botocore`` client config dict. + """ + + def __init__( + self, + *, + sqs_queue: str, + aws_conn_id: str | None = "aws_default", + max_messages: int = 10, + wait_time_seconds: int = 20, + visibility_timeout: int = 120, + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, + ) -> None: + self.sqs_queue = sqs_queue + self.aws_conn_id = aws_conn_id + self.max_messages = max_messages + self.wait_time_seconds = wait_time_seconds + self.visibility_timeout = visibility_timeout + self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config + self._exit_stack: AsyncExitStack | None = None + self._client: Any = None + self._renew_task: asyncio.Task | None = None + # Outstanding messages keyed by receipt handle: added on yield, removed + # when advance resolves them. The renewer extends visibility for these. 
+ self._outstanding: dict[str, SqsBrokerPayload] = {} + + @property + def hook(self) -> SqsHook: + return SqsHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) + + async def open_stream(self) -> AsyncIterator[tuple[Any, SqsBrokerPayload]]: + """Open the consumer lazily and yield (body, SqsBrokerPayload) per message.""" + self._exit_stack = AsyncExitStack() + self._client = await self._exit_stack.enter_async_context(await self.hook.get_async_conn()) + self._renew_task = asyncio.create_task(self._renew_visibility_loop()) + while True: + response = await self._client.receive_message( + QueueUrl=self.sqs_queue, + MaxNumberOfMessages=self.max_messages, + WaitTimeSeconds=self.wait_time_seconds, + VisibilityTimeout=self.visibility_timeout, + # Standard queues have no MessageGroupId; SQS simply omits it from + # the response rather than erroring, so this is safe to request + # unconditionally and lets one code path serve standard and FIFO. + MessageSystemAttributeNames=["MessageGroupId"], + ) + for message in response.get("Messages", []): + payload = SqsBrokerPayload( + receipt_handle=message["ReceiptHandle"], + message_id=message["MessageId"], + group_id=message.get("Attributes", {}).get("MessageGroupId"), + ) + self._outstanding[payload.receipt_handle] = payload + yield message.get("Body"), payload + + def get_advance_lane(self, broker_payload: SqsBrokerPayload) -> Hashable: + """ + Order advances per message group, falling back to per-message. + + A FIFO queue guarantees order within a ``MessageGroupId``, so messages + sharing one must advance in fan-out order -- the group is the lane. A + standard queue has no ordering and its deletes are independent, so each + message gets its own lane (its receipt handle) for maximum parallelism + with no head-of-line blocking. 
+ """ + return broker_payload.group_id or broker_payload.receipt_handle + + async def advance(self, batch: Sequence[AdvanceItem]) -> None: + """ + Resolve each message in the batch independently. + + Deletes do not accumulate, so each item is handled on its own: + + * ``acked`` -- accepted and persisted; delete it. + * ``rejected`` -- terminally refused; delete it (dropped, not redelivered). + * ``failed`` -- an involuntary failure (ack timeout / overflow); reset its + visibility to zero so SQS redelivers it immediately, rather than waiting + out the remaining visibility timeout. + * all-zero (a broadcast no subscriber was online for) -- leave it to lapse + and be redelivered once a subscriber returns. Unlike a failure it is not + reset to zero: an immediate redelivery would busy-loop while no + subscriber is online, so letting it lapse gives one time to return. + + Every item is removed from the outstanding set first so the renewer stops + extending its visibility. A delete or visibility reset the broker reports + as failed is logged, not raised: the message simply lapses and is + redelivered, preserving at-least-once. 
+ """ + to_delete: list[SqsBrokerPayload] = [] + to_redeliver: list[SqsBrokerPayload] = [] + for item in batch: + payload = item.broker_payload + outcome = item.outcome + self._outstanding.pop(payload.receipt_handle, None) + if outcome.failed: + to_redeliver.append(payload) + elif outcome.acked == 0 and outcome.rejected == 0: + continue + else: + to_delete.append(payload) + if to_delete: + await self._delete(to_delete) + if to_redeliver: + await self._redeliver(to_redeliver) + + async def _delete(self, payloads: list[SqsBrokerPayload]) -> None: + for chunk in _in_batches(payloads, _SQS_BATCH_LIMIT): + entries = [ + {"Id": str(index), "ReceiptHandle": payload.receipt_handle} + for index, payload in enumerate(chunk) + ] + response = await self._client.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries) + failed = response.get("Failed", []) + if failed: + log.warning( + "Failed to delete %d SQS message(s); they will be redelivered: %s", len(failed), failed + ) + + async def _redeliver(self, payloads: list[SqsBrokerPayload]) -> None: + """Reset visibility to zero so SQS redelivers the failed messages immediately.""" + for chunk in _in_batches(payloads, _SQS_BATCH_LIMIT): + entries = [ + {"Id": str(index), "ReceiptHandle": payload.receipt_handle, "VisibilityTimeout": 0} + for index, payload in enumerate(chunk) + ] + try: + await self._client.change_message_visibility_batch(QueueUrl=self.sqs_queue, Entries=entries) + except Exception: + log.warning( + "Failed to reset SQS visibility for %d message(s); they will lapse instead", + len(entries), + exc_info=True, + ) + + async def _renew_visibility_loop(self) -> None: + interval = max(1, self.visibility_timeout // 2) + while True: + await asyncio.sleep(interval) + await self._renew_visibility() + + async def _renew_visibility(self) -> None: + payloads = list(self._outstanding.values()) + for chunk in _in_batches(payloads, _SQS_BATCH_LIMIT): + entries = [ + { + "Id": str(index), + "ReceiptHandle": 
payload.receipt_handle, + "VisibilityTimeout": self.visibility_timeout, + } + for index, payload in enumerate(chunk) + ] + try: + await self._client.change_message_visibility_batch(QueueUrl=self.sqs_queue, Entries=entries) + except Exception: + log.warning("Failed to extend SQS visibility for %d message(s)", len(entries), exc_info=True) + + async def aclose(self) -> None: + """Cancel the renewer and close the SQS client when the poll ends; best-effort.""" + task = self._renew_task + self._renew_task = None + if task is not None: + task.cancel() + with suppress(asyncio.CancelledError): + await task + stack = self._exit_stack + self._exit_stack = None + if stack is not None: + try: + await stack.aclose() + except Exception: + log.warning("Failed to close SQS client", exc_info=True) + self._client = None + self._outstanding.clear() + + +class SqsSharedStreamTrigger(BaseEventTrigger): + """ + Event-driven trigger that watches an SQS queue through a shared consumer. + + Triggers that declare the same ``sqs_queue`` + ``aws_conn_id`` share one + underlying SQS consumer in the triggerer (a single receive loop broadcast + to every subscriber). Each subscriber fires a ``TriggerEvent`` per message; + override :meth:`filter_shared_stream` to fire only for the messages this + trigger cares about. + + Designed to back an :class:`~airflow.sdk.AssetWatcher` for event-driven + scheduling. A message is deleted only after the derived ``TriggerEvent`` is + persisted -- never on receipt; see :class:`SqsSharedStreamProducer` for the + at-least-once delete semantics. + + :param sqs_queue: URL of the SQS queue to watch. + :param aws_conn_id: AWS connection id, defaults to ``aws_default``. + :param max_messages: Maximum messages to fetch per ``ReceiveMessage`` (1-10). + :param wait_time_seconds: Long-poll wait per ``ReceiveMessage`` call. + :param visibility_timeout: Visibility timeout applied at receive time and on + every renewal; see :class:`SqsSharedStreamProducer`. 
+ :param region_name: AWS region of the queue. + :param verify: ``botocore`` TLS verification flag passed to the hook. + :param botocore_config: Optional ``botocore`` client config dict. + """ + + def __init__( + self, + *, + sqs_queue: str, + aws_conn_id: str | None = "aws_default", + max_messages: int = 10, + wait_time_seconds: int = 20, + visibility_timeout: int = 120, + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, + ) -> None: + super().__init__() + self.sqs_queue = sqs_queue + self.aws_conn_id = aws_conn_id + self.max_messages = max_messages + self.wait_time_seconds = wait_time_seconds + self.visibility_timeout = visibility_timeout + self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.amazon.aws.triggers.shared_stream.SqsSharedStreamTrigger", + { + "sqs_queue": self.sqs_queue, + "aws_conn_id": self.aws_conn_id, + "max_messages": self.max_messages, + "wait_time_seconds": self.wait_time_seconds, + "visibility_timeout": self.visibility_timeout, + "region_name": self.region_name, + "verify": self.verify, + "botocore_config": self.botocore_config, + }, + ) + + def shared_stream_key(self) -> Hashable: + """Triggers on the same queue + connection share one consumer.""" + return ("sqs-shared-stream", self.sqs_queue, self.aws_conn_id, self.region_name) + + @classmethod + def create_shared_stream_producer(cls, kwargs: dict[str, Any]) -> SqsSharedStreamProducer: + return SqsSharedStreamProducer( + sqs_queue=kwargs["sqs_queue"], + aws_conn_id=kwargs["aws_conn_id"], + max_messages=kwargs["max_messages"], + wait_time_seconds=kwargs["wait_time_seconds"], + visibility_timeout=kwargs["visibility_timeout"], + region_name=kwargs.get("region_name"), + verify=kwargs.get("verify"), + botocore_config=kwargs.get("botocore_config"), + ) + + async def filter_shared_stream(self, shared_stream: 
AsyncIterator[Any]) -> AsyncIterator[TriggerEvent]: + """Fire one ``TriggerEvent`` per message body. Override to filter or transform.""" + async for body in shared_stream: + yield TriggerEvent(body) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """ + Not supported -- this trigger runs only through the shared-stream manager. + + ``shared_stream_key`` always returns non-``None``, so the triggerer + drives this trigger through :meth:`filter_shared_stream`; the shared + group owns the SQS consumer and message deletes. There is no standalone + path: deleting safely needs the ack channel to gate it on trigger-event + persistence, which only the manager provides. + """ + raise NotImplementedError( + "SqsSharedStreamTrigger runs only through the triggerer's shared-stream " + "manager (via filter_shared_stream); it has no standalone run() path." + ) + yield # pragma: no cover - marks this as an async generator diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 3a407bb1b2fa6..fcc3886a3ae0f 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -983,7 +983,10 @@ def get_provider_info(): }, { "integration-name": "Amazon Simple Queue Service (SQS)", - "python-modules": ["airflow.providers.amazon.aws.triggers.sqs"], + "python-modules": [ + "airflow.providers.amazon.aws.triggers.sqs", + "airflow.providers.amazon.aws.triggers.shared_stream", + ], }, { "integration-name": "AWS Glue DataBrew", diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_shared_stream.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_shared_stream.py new file mode 100644 index 0000000000000..75158e3cef5e5 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_shared_stream.py @@ -0,0 +1,398 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# 
or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from unittest import mock +from unittest.mock import AsyncMock + +import pytest + +pytest.importorskip("airflow.triggers.shared_stream") + +from airflow.providers.amazon.aws.hooks.sqs import SqsHook +from airflow.providers.amazon.aws.triggers.shared_stream import ( + SqsBrokerPayload, + SqsSharedStreamProducer, + SqsSharedStreamTrigger, +) +from airflow.triggers.base import TriggerEvent +from airflow.triggers.shared_stream import AdvanceItem, AdvanceOutcome + + +class _FakeSqsClient: + """A minimal stand-in for an aiobotocore SQS client (also its own async CM).""" + + def __init__(self, message_batches=(), delete_failed=()): + self._batches = list(message_batches) + self._delete_failed = list(delete_failed) + self.receive_kwargs: list[dict] = [] + self.deleted_entries: list[list[dict]] = [] + self.visibility_entries: list[list[dict]] = [] + self.closed = False + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + self.closed = True + return False + + async def receive_message(self, **kwargs): + self.receive_kwargs.append(kwargs) + if self._batches: + return {"Messages": self._batches.pop(0)} + return {"Messages": []} + + async 
def delete_message_batch(self, **kwargs): + self.deleted_entries.append(list(kwargs["Entries"])) + return {"Successful": list(kwargs["Entries"]), "Failed": list(self._delete_failed)} + + async def change_message_visibility_batch(self, **kwargs): + self.visibility_entries.append(list(kwargs["Entries"])) + return {"Successful": list(kwargs["Entries"])} + + +def _msg(message_id, receipt_handle=None, body="body", group_id=None): + message = { + "MessageId": message_id, + "ReceiptHandle": receipt_handle or f"rh-{message_id}", + "Body": body, + } + if group_id is not None: + message["Attributes"] = {"MessageGroupId": group_id} + return message + + +def _payload(message_id, receipt_handle=None, group_id=None): + return SqsBrokerPayload( + receipt_handle=receipt_handle or f"rh-{message_id}", + message_id=message_id, + group_id=group_id, + ) + + +def _item(message_id, outcome): + return AdvanceItem(_payload(message_id), outcome) + + +def _acked(message_id): + return _item(message_id, AdvanceOutcome(acked=1, failed=0)) + + +def _rejected(message_id): + return _item(message_id, AdvanceOutcome(acked=0, failed=0, rejected=1)) + + +def _failed(message_id): + return _item(message_id, AdvanceOutcome(acked=0, failed=1)) + + +def _zero_subscriber(message_id): + return _item(message_id, AdvanceOutcome(acked=0, failed=0)) + + +def _outstanding_from(items): + return {item.broker_payload.receipt_handle: item.broker_payload for item in items} + + +class TestSqsSharedStreamProducer: + @pytest.mark.asyncio + @mock.patch.object(SqsHook, "get_async_conn", new_callable=AsyncMock) + async def test_open_stream_yields_body_and_payload_and_tracks_outstanding(self, mock_get_async_conn): + client = _FakeSqsClient([[_msg("m1", "rh1", "hello")]]) + mock_get_async_conn.return_value = client + + producer = SqsSharedStreamProducer(sqs_queue="q", visibility_timeout=30) + stream = producer.open_stream() + body, payload = await anext(stream) + + assert body == "hello" + assert payload == 
SqsBrokerPayload(receipt_handle="rh1", message_id="m1", group_id=None) + assert producer._outstanding == {"rh1": payload} + # The receive call carries the configured visibility timeout so the message + # stays invisible from the first moment, and requests MessageGroupId so a + # FIFO queue can be ordered by group. + assert client.receive_kwargs[0]["VisibilityTimeout"] == 30 + assert client.receive_kwargs[0]["MessageSystemAttributeNames"] == ["MessageGroupId"] + # Opening the stream starts the background visibility renewer. + assert producer._renew_task is not None + + await producer.aclose() + await stream.aclose() + + @pytest.mark.asyncio + @mock.patch.object(SqsHook, "get_async_conn", new_callable=AsyncMock) + async def test_open_stream_captures_group_id_for_fifo(self, mock_get_async_conn): + client = _FakeSqsClient([[_msg("m1", "rh1", "hello", group_id="orders")]]) + mock_get_async_conn.return_value = client + + producer = SqsSharedStreamProducer(sqs_queue="q.fifo") + stream = producer.open_stream() + _body, payload = await anext(stream) + + # FIFO messages carry MessageGroupId, which becomes the advance lane. 
+ assert payload.group_id == "orders" + assert producer.get_advance_lane(payload) == "orders" + + await producer.aclose() + await stream.aclose() + + @pytest.mark.parametrize( + ("group_id", "expected_lane"), + [ + pytest.param(None, "rh-m1", id="standard-falls-back-to-receipt-handle"), + pytest.param("orders", "orders", id="fifo-uses-message-group-id"), + ], + ) + def test_get_advance_lane_prefers_group_then_receipt_handle(self, group_id, expected_lane): + producer = SqsSharedStreamProducer(sqs_queue="q") + assert producer.get_advance_lane(_payload("m1", group_id=group_id)) == expected_lane + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("item_factory", "should_delete", "should_redeliver"), + [ + pytest.param(_acked, True, False, id="acked-deleted"), + pytest.param(_rejected, True, False, id="rejected-dropped-by-delete"), + pytest.param(_failed, False, True, id="failed-reset-to-zero-for-immediate-redelivery"), + pytest.param(_zero_subscriber, False, False, id="zero-subscriber-left-to-lapse"), + ], + ) + async def test_advance_resolves_each_outcome_and_always_clears_outstanding( + self, item_factory, should_delete, should_redeliver + ): + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + item = item_factory("m1") + producer._outstanding = _outstanding_from([item]) + + await producer.advance([item]) + + deleted = [entry["ReceiptHandle"] for batch in client.deleted_entries for entry in batch] + assert (deleted == ["rh-m1"]) is should_delete + # A failed message is reset to visibility 0 for immediate redelivery; a + # zero-subscriber message is left to lapse instead (no visibility change). + reset = [ + entry["ReceiptHandle"] + for batch in client.visibility_entries + for entry in batch + if entry["VisibilityTimeout"] == 0 + ] + assert (reset == ["rh-m1"]) is should_redeliver + # Every item is removed so the renewer stops extending it. 
+ assert producer._outstanding == {} + + @pytest.mark.asyncio + async def test_advance_mixed_batch_deletes_accepted_and_redelivers_failed(self): + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + items = [_acked("a"), _failed("b"), _rejected("c"), _zero_subscriber("d")] + producer._outstanding = _outstanding_from(items) + + await producer.advance(items) + + deleted = {entry["ReceiptHandle"] for batch in client.deleted_entries for entry in batch} + assert deleted == {"rh-a", "rh-c"} + reset = { + entry["ReceiptHandle"] + for batch in client.visibility_entries + for entry in batch + if entry["VisibilityTimeout"] == 0 + } + # Only the failed message is reset for immediate redelivery; the + # zero-subscriber one (d) is left to lapse. + assert reset == {"rh-b"} + assert producer._outstanding == {} + + @pytest.mark.asyncio + async def test_advance_batches_deletes_in_groups_of_ten(self): + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + items = [_acked(f"m{i}") for i in range(21)] + producer._outstanding = _outstanding_from(items) + + await producer.advance(items) + + assert [len(batch) for batch in client.deleted_entries] == [10, 10, 1] + + @pytest.mark.asyncio + async def test_advance_delete_failure_is_not_raised(self): + client = _FakeSqsClient(delete_failed=[{"Id": "0", "SenderFault": True, "Code": "x"}]) + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + item = _acked("a") + producer._outstanding = _outstanding_from([item]) + + await producer.advance([item]) + + assert client.deleted_entries # delete was attempted; reported failure is swallowed + + @pytest.mark.asyncio + async def test_advance_batches_redeliveries_in_groups_of_ten(self): + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + items = [_failed(f"m{i}") for i in range(21)] + producer._outstanding = 
_outstanding_from(items) + + await producer.advance(items) + + assert [len(batch) for batch in client.visibility_entries] == [10, 10, 1] + assert all(entry["VisibilityTimeout"] == 0 for batch in client.visibility_entries for entry in batch) + + @pytest.mark.asyncio + async def test_advance_redeliver_failure_is_not_raised(self): + client = _FakeSqsClient() + client.change_message_visibility_batch = AsyncMock(side_effect=RuntimeError("boom")) + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + item = _failed("a") + producer._outstanding = _outstanding_from([item]) + + await producer.advance([item]) # no raise; the message lapses instead + + assert producer._outstanding == {} + + @pytest.mark.asyncio + async def test_renew_visibility_extends_all_outstanding(self): + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q", visibility_timeout=100) + producer._client = client + producer._outstanding = _outstanding_from([_acked("m1"), _acked("m2")]) + + await producer._renew_visibility() + + entries = client.visibility_entries[0] + assert {entry["ReceiptHandle"] for entry in entries} == {"rh-m1", "rh-m2"} + assert all(entry["VisibilityTimeout"] == 100 for entry in entries) + + @pytest.mark.asyncio + async def test_renew_visibility_failure_is_swallowed(self): + client = _FakeSqsClient() + client.change_message_visibility_batch = AsyncMock(side_effect=RuntimeError("boom")) + producer = SqsSharedStreamProducer(sqs_queue="q") + producer._client = client + producer._outstanding = {"rh1": _payload("m1")} + + await producer._renew_visibility() # no raise + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.shared_stream.asyncio.sleep", new_callable=AsyncMock) + async def test_renew_visibility_loop_renews_each_interval(self, mock_sleep): + mock_sleep.side_effect = [None, asyncio.CancelledError] + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q", visibility_timeout=10) + 
producer._client = client + producer._outstanding = {"rh1": _payload("m1")} + + with pytest.raises(asyncio.CancelledError): + await producer._renew_visibility_loop() + + # interval is half the visibility timeout, and one renewal ran before cancel. + mock_sleep.assert_awaited_with(5) + assert len(client.visibility_entries) == 1 + + @pytest.mark.asyncio + async def test_aclose_cancels_renewer_and_closes_client_once(self): + client = _FakeSqsClient() + producer = SqsSharedStreamProducer(sqs_queue="q") + stack = AsyncExitStack() + producer._exit_stack = stack + producer._client = await stack.enter_async_context(client) + producer._renew_task = asyncio.create_task(asyncio.sleep(3600)) + producer._outstanding = {"rh1": _payload("m1")} + + await producer.aclose() + + assert client.closed is True + assert producer._client is None + assert producer._renew_task is None + assert producer._outstanding == {} + # Second close is a no-op. + await producer.aclose() + + +class TestSqsSharedStreamTrigger: + def test_serialize_roundtrips_classpath_and_kwargs(self): + trigger = SqsSharedStreamTrigger( + sqs_queue="q", + aws_conn_id="c", + max_messages=5, + wait_time_seconds=10, + visibility_timeout=30, + region_name="us-east-1", + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.shared_stream.SqsSharedStreamTrigger" + assert kwargs == { + "sqs_queue": "q", + "aws_conn_id": "c", + "max_messages": 5, + "wait_time_seconds": 10, + "visibility_timeout": 30, + "region_name": "us-east-1", + "verify": None, + "botocore_config": None, + } + + def test_shared_stream_key_shares_on_queue_and_connection(self): + # Poll-tuning fields (max_messages, ...) do not change which queue is watched. + a = SqsSharedStreamTrigger(sqs_queue="q", aws_conn_id="c", max_messages=1) + b = SqsSharedStreamTrigger(sqs_queue="q", aws_conn_id="c", max_messages=99) + assert a.shared_stream_key() == b.shared_stream_key() + # Must be usable as a dict key. 
+ assert {a.shared_stream_key(): 1} + + def test_shared_stream_key_differs_by_queue(self): + a = SqsSharedStreamTrigger(sqs_queue="q1") + b = SqsSharedStreamTrigger(sqs_queue="q2") + assert a.shared_stream_key() != b.shared_stream_key() + + def test_create_shared_stream_producer_builds_producer_from_kwargs(self): + trigger = SqsSharedStreamTrigger( + sqs_queue="q", aws_conn_id="c", max_messages=5, wait_time_seconds=10, visibility_timeout=30 + ) + producer = trigger.create_shared_stream_producer(trigger.serialize()[1]) + assert isinstance(producer, SqsSharedStreamProducer) + assert producer.sqs_queue == "q" + assert producer.aws_conn_id == "c" + assert producer.max_messages == 5 + assert producer.wait_time_seconds == 10 + assert producer.visibility_timeout == 30 + + @pytest.mark.asyncio + async def test_filter_shared_stream_yields_one_event_per_body(self): + trigger = SqsSharedStreamTrigger(sqs_queue="q") + + async def raw_stream(): + yield "first" + yield "second" + + events = [event async for event in trigger.filter_shared_stream(raw_stream())] + assert [type(e) for e in events] == [TriggerEvent, TriggerEvent] + assert [e.payload for e in events] == ["first", "second"] + + @pytest.mark.asyncio + async def test_run_is_not_supported_standalone(self): + trigger = SqsSharedStreamTrigger(sqs_queue="q") + with pytest.raises(NotImplementedError, match="shared-stream manager"): + await anext(trigger.run())