From bff081b1b759cab567e2a1dc9616af58b4e64d2d Mon Sep 17 00:00:00 2001 From: Rocky Jaiswal Date: Thu, 4 Jun 2026 13:15:18 +0200 Subject: [PATCH 1/2] fix(transport): harden Redis SQS-style queue (load balancing, bounded growth) + repair FastStream 0.7 filtering --- sdk/CHANGELOG.md | 35 +++ sdk/eggai/transport/inmemory.py | 10 + sdk/eggai/transport/kafka.py | 54 ++-- sdk/eggai/transport/middleware_utils.py | 152 +++++----- sdk/eggai/transport/pending_reclaimer.py | 23 +- sdk/eggai/transport/redis.py | 164 +++++++---- sdk/tests/test_middleware_utils.py | 128 ++++++++ sdk/tests/test_redis.py | 355 +++++++++++++++++++++++ 8 files changed, 749 insertions(+), 172 deletions(-) create mode 100644 sdk/tests/test_middleware_utils.py diff --git a/sdk/CHANGELOG.md b/sdk/CHANGELOG.md index 68b4d615..dce1ac19 100644 --- a/sdk/CHANGELOG.md +++ b/sdk/CHANGELOG.md @@ -7,6 +7,41 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed +- **Redis & Kafka transports**: `filter_by_message`, `data_type`, and + `filter_by_data` subscriptions raised `TypeError` at subscribe time under + FastStream 0.7, which removed publisher/subscriber-level middlewares. Filtering + and typed-message handling are now applied in EggAI's own handler wrapper + (application code) rather than via FastStream subscriber middlewares — aligning + with FastStream 0.7's removal of subscriber/publisher middlewares. This keeps the + per-subscription filter logic independent of FastStream's middleware API (it is + the same approach the in-memory transport already uses). As part of this, + `data_type` subscriptions on Redis/Kafka now deliver the **typed model instance** + to the handler (matching the in-memory transport and the documented behaviour), + rather than the raw dict. Supplying both `data_type` and `filter_by_message` + (mutually exclusive — use `filter_by_data` for typed filtering) now raises + `ValueError` instead of silently dropping the raw-dict predicate. + +### Added +- **RedisTransport**: New `max_len` and `retry_max_len` constructor options to cap + Redis stream growth via approximate trimming (`XADD ... MAXLEN ~`). `max_len` + (default `None`/unbounded) caps the producer/`publish()` path; it is opt-in + because `MAXLEN` trims the oldest entries by count regardless of ack state, so a + value below `throughput × consumer-lag` can drop un-delivered messages. + `retry_max_len` (default `10_000`) caps the SDK-managed retry and DLQ streams, + bounding the blast radius of a runaway retry loop. This wires up the previously + documented-but-inert `max_len` knob. + +### Changed +- **RedisTransport**: The stream consumer name now defaults to a per-process-unique + value (`{handler_id}-{hostname}-{pid}`) while the consumer **group** still defaults + to the stable `handler_id`. A fleet of workers running the same handler now shares + one group (Redis load-balances the stream across them) while each worker owns a + distinct slice of the pending-entries list — the competing-consumers pattern. The + auto-created retry-stream subscriber gets the same per-process-unique consumer, so + retried messages load-balance across a worker fleet too. Pass an explicit + `consumer=` to opt out. + ## [0.3.0] - 2026-06-03 ### Security diff --git a/sdk/eggai/transport/inmemory.py b/sdk/eggai/transport/inmemory.py index a6a01026..c70729ea 100644 --- a/sdk/eggai/transport/inmemory.py +++ b/sdk/eggai/transport/inmemory.py @@ -124,6 +124,16 @@ async def subscribe( final_callback = callback + # filter_by_message (raw-dict filter) and data_type (typed validation + + # filter_by_data) are mutually exclusive — reject the combination rather + # than silently ignoring one, matching the Kafka/Redis transports. + if "data_type" in kwargs and kwargs.get("filter_by_message") is not None: + raise ValueError( + "filter_by_message cannot be combined with data_type. Use " + "filter_by_data to filter typed subscriptions, or filter_by_message " + "on its own for raw-dict filtering." + ) + # Handle data_type filtering if "data_type" in kwargs: data_type = kwargs["data_type"] diff --git a/sdk/eggai/transport/kafka.py b/sdk/eggai/transport/kafka.py index 07b7846c..9842ab1f 100644 --- a/sdk/eggai/transport/kafka.py +++ b/sdk/eggai/transport/kafka.py @@ -8,11 +8,7 @@ from eggai.schemas import BaseMessage from eggai.transport.base import Transport -from eggai.transport.middleware_utils import ( - create_data_type_middleware, - create_filter_by_data_middleware, - create_filter_middleware, -) +from eggai.transport.middleware_utils import wrap_handler_with_filters class KafkaTransport(Transport): @@ -181,7 +177,14 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: **kwargs: Additional keyword arguments that can be used to configure the subscription. Keyword Args: - filter_by_message (Callable, optional): A function to filter incoming messages based on their payload. + filter_by_message (Callable, optional): Predicate applied to the decoded message dict; the handler + is invoked (with the dict) only when it returns truthy. + data_type (BaseModel, optional): A Pydantic model class used to validate and type incoming + messages. Messages that fail validation, or whose ``type`` field does not match the model's + default ``type``, are skipped. Matching messages are passed to the handler as the **typed + model instance** (not the raw dict). + filter_by_data (Callable, optional): Predicate applied to the validated typed message (requires + `data_type`); the handler runs only when it returns truthy. group_id (Optional[str], optional): The consumer group name for dynamic partition assignment. auto_offset_reset (str, optional): Policy for resetting offsets ('earliest' or 'latest'). @@ -190,30 +193,21 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: """ from eggai.tracing import make_tracing_wrapper - handler = make_tracing_wrapper(channel, handler) - - if "filter_by_message" in kwargs: - if "middlewares" not in kwargs: - kwargs["middlewares"] = [] - kwargs["middlewares"].append( - create_filter_middleware(kwargs.pop("filter_by_message")) - ) - - if "data_type" in kwargs: - data_type = kwargs.pop("data_type") - - if "middlewares" not in kwargs: - kwargs["middlewares"] = [] - kwargs["middlewares"].append(create_data_type_middleware(data_type)) - - if "filter_by_data" in kwargs: - if "middlewares" not in kwargs: - kwargs["middlewares"] = [] - kwargs["middlewares"].append( - create_filter_by_data_middleware( - data_type, kwargs.pop("filter_by_data") - ) - ) + # EggAI applies content filtering (filter_by_message) and typed-subscription + # support (data_type / filter_by_data) by wrapping the handler — see + # wrap_handler_with_filters — NOT via FastStream subscriber middlewares, + # which FastStream 0.7 removed. Tracing stays OUTERMOST so traceparent is + # read from the raw decoded dict before a data_type subscription validates + # it into a typed model: tracing( filters( handler ) ). + handler = make_tracing_wrapper( + channel, + wrap_handler_with_filters( + handler, + filter_by_message=kwargs.pop("filter_by_message", None), + data_type=kwargs.pop("data_type", None), + filter_by_data=kwargs.pop("filter_by_data", None), + ), + ) # Use handler_id as default group_id (preserves broadcast behavior) handler_id = kwargs.pop("handler_id", None) diff --git a/sdk/eggai/transport/middleware_utils.py b/sdk/eggai/transport/middleware_utils.py index 14a134ff..ae44e572 100644 --- a/sdk/eggai/transport/middleware_utils.py +++ b/sdk/eggai/transport/middleware_utils.py @@ -1,93 +1,81 @@ """ -Shared middleware utilities for transport implementations. - -This module provides common middleware factories used by Kafka and Redis transports -to handle message filtering and data type validation. +Shared message-filtering utilities for transport implementations. + +Kafka and Redis both apply EggAI's content filtering (``filter_by_message``) and +typed-subscription support (``data_type`` / ``filter_by_data``) by wrapping the +handler, *not* via broker subscriber middlewares. FastStream 0.7 removed the +``middlewares`` argument from ``subscriber()`` (and from the broker constructor), +so the old middleware-based approach raised ``TypeError`` at subscribe time. This +handler-wrapping approach is independent of FastStream's middleware API and keeps +the behaviour identical across the Kafka, Redis, and in-memory transports. """ -import json -from collections.abc import Awaitable, Callable +from collections.abc import Callable from typing import Any -from faststream.message.message import StreamMessage - - -def create_filter_middleware(filter_func: Callable[[dict[str, Any]], bool]) -> Callable: - """ - Create a middleware that filters messages based on a predicate function. - - Args: - filter_func: A function that takes a message dict and returns True if the message - should be processed, False otherwise. - - Returns: - A middleware function that applies the filter. - """ - - async def middleware( - call_next: Callable[[Any], Awaitable[Any]], - msg: StreamMessage[Any], - ) -> Any: - if filter_func(json.loads(msg.body.decode("utf-8"))): - return await call_next(msg) - return None - - return middleware - - -def create_data_type_middleware(data_type: type) -> Callable: - """ - Create a middleware that validates and filters messages by data type. - - Args: - data_type: A Pydantic model class with a 'type' field that will be used - for validation and filtering. - - Returns: - A middleware function that validates the message against the data type - and filters out messages that don't match the expected type. - """ - - async def middleware( - call_next: Callable[[Any], Awaitable[Any]], - msg: StreamMessage[Any], - ) -> Any: - typed_message = data_type.model_validate(json.loads(msg.body.decode("utf-8"))) - - if typed_message.type != data_type.model_fields["type"].default: - return None - - return await call_next(msg) - - return middleware +from pydantic import ValidationError -def create_filter_by_data_middleware( - data_type: type, filter_func: Callable[[Any], bool] +def wrap_handler_with_filters( + handler: Callable, + *, + data_type: type | None = None, + filter_by_data: Callable[[Any], bool] | None = None, + filter_by_message: Callable[[dict[str, Any]], bool] | None = None, ) -> Callable: + """Wrap ``handler`` with EggAI's content filtering / typed-message support. + + The returned coroutine receives the broker-decoded message (a ``dict``) and: + + - ``data_type``: validates the dict against the Pydantic model. Messages that + fail validation, or whose ``type`` field does not match the model's default + ``type``, are skipped. Matching messages are passed to ``handler`` as the + **typed model instance** (e.g. ``OrderMessage``), not the raw dict. + - ``data_type`` + ``filter_by_data``: as above, and additionally skipped unless + ``filter_by_data(typed_message)`` returns truthy. + - ``filter_by_message`` (no ``data_type``): ``handler`` is called with the raw + dict only when ``filter_by_message(dict)`` returns truthy. + + Skipped messages return ``None`` without invoking ``handler`` — a clean no-op, + so the broker acknowledges them (they are not retried). When no filtering + option is supplied, ``handler`` is returned unchanged. + + ``filter_by_message`` and ``data_type`` are mutually exclusive: the former is + the untyped (raw-dict) filter, the latter validates into a typed model and + pairs with ``filter_by_data``. Supplying both is rejected rather than silently + dropping one of them. """ - Create a middleware that validates messages by data type and applies a filter. - - This combines data type validation with a custom filter function that operates - on the validated/typed message object. - - Args: - data_type: A Pydantic model class for validation. - filter_func: A function that takes the validated message object and returns - True if it should be processed, False otherwise. + if data_type is not None and filter_by_message is not None: + raise ValueError( + "filter_by_message cannot be combined with data_type. Use filter_by_data " + "(which receives the validated typed message) to filter typed " + "subscriptions, or filter_by_message on its own for raw-dict filtering." + ) + + if data_type is not None: + expected_type = data_type.model_fields["type"].default + + async def typed_handler(message: dict[str, Any]) -> Any: + try: + typed_message = data_type.model_validate(message) + except (ValidationError, ValueError, TypeError): + # Wrong shape / payload for this data_type — not ours to handle. + return None + if typed_message.type != expected_type: + return None + if filter_by_data is not None and not filter_by_data(typed_message): + return None + return await handler(typed_message) + + return typed_handler + + if filter_by_message is not None: + + async def filtered_handler(message: dict[str, Any]) -> Any: + if filter_by_message(message): + return await handler(message) + return None - Returns: - A middleware function that validates and filters messages. - """ + return filtered_handler - async def middleware( - call_next: Callable[[Any], Awaitable[Any]], - msg: StreamMessage[Any], - ) -> Any: - data = json.loads(msg.body.decode("utf-8")) - typed_message = data_type.model_validate(data) - if filter_func(typed_message): - return await call_next(msg) - return None - - return middleware + return handler diff --git a/sdk/eggai/transport/pending_reclaimer.py b/sdk/eggai/transport/pending_reclaimer.py index 034ec3ed..9f6e529c 100644 --- a/sdk/eggai/transport/pending_reclaimer.py +++ b/sdk/eggai/transport/pending_reclaimer.py @@ -53,6 +53,9 @@ class ReclaimerConfig: None # full key, e.g. "eggai.orders.order-service-handle_order-1.dlq" ) on_dlq: Callable | None = None # async or sync callback(fields, msg_id, count) + max_len: int | None = ( + None # cap retry/DLQ stream length (XADD MAXLEN ~); None = unbounded + ) def _inject_retry_metadata(data: bytes, msg_id_str: str) -> tuple[bytes, int, bool]: @@ -215,6 +218,20 @@ async def _ensure_group(self, config: ReclaimerConfig) -> None: if "BUSYGROUP" not in str(e): raise + async def _xadd(self, stream: str, fields: dict, max_len: int | None) -> None: + """XADD with optional approximate length capping. + + ``approximate=True`` (the ``MAXLEN ~`` form) lets Redis trim on whole-node + boundaries, which is far cheaper than exact trimming and is the + recommended production setting. ``max_len=None`` means no trimming. + """ + if max_len is not None: + await self._redis_client.xadd( + stream, fields, maxlen=max_len, approximate=True + ) + else: + await self._redis_client.xadd(stream, fields) + async def _reclaim_once(self, config: ReclaimerConfig) -> None: # --- Paginated XPENDING scan --- # A fixed count=100 only scans one page; with large or high-traffic PELs, @@ -273,7 +290,7 @@ async def _reclaim_once(self, config: ReclaimerConfig) -> None: # (XACK) with a loud warning rather than spin on it indefinitely. if not parsed_ok: if config.dlq_stream is not None: - await self._redis_client.xadd(config.dlq_stream, fields) + await self._xadd(config.dlq_stream, fields, config.max_len) await self._redis_client.xack(config.stream, config.group, msg_id) logger.warning( "Message %s has an unparseable envelope; moved to DLQ %s " @@ -299,7 +316,7 @@ async def _reclaim_once(self, config: ReclaimerConfig) -> None: and config.dlq_stream is not None and new_count > config.max_retries ): - await self._redis_client.xadd(config.dlq_stream, fields) + await self._xadd(config.dlq_stream, fields, config.max_len) await self._redis_client.xack(config.stream, config.group, msg_id) logger.warning( "Message %s exceeded max_retries=%d; moved to DLQ %s", @@ -311,7 +328,7 @@ async def _reclaim_once(self, config: ReclaimerConfig) -> None: config, fields, data_key, msg_id_str, new_count ) else: - await self._redis_client.xadd(config.retry_stream, fields) + await self._xadd(config.retry_stream, fields, config.max_len) await self._redis_client.xack(config.stream, config.group, msg_id) logger.debug("Reclaimed %s → %s", msg_id_str, config.retry_stream) diff --git a/sdk/eggai/transport/redis.py b/sdk/eggai/transport/redis.py index b6ceee2b..68847995 100644 --- a/sdk/eggai/transport/redis.py +++ b/sdk/eggai/transport/redis.py @@ -1,5 +1,7 @@ import asyncio import logging +import os +import socket from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -11,15 +13,17 @@ from eggai.schemas import BaseMessage from eggai.transport.base import Transport -from eggai.transport.middleware_utils import ( - create_data_type_middleware, - create_filter_by_data_middleware, - create_filter_middleware, -) +from eggai.transport.middleware_utils import wrap_handler_with_filters from eggai.transport.pending_reclaimer import PendingReclaimerManager, ReclaimerConfig logger = logging.getLogger(__name__) +# Stable for the lifetime of this process, unique across processes/hosts. Used as +# the default Redis stream consumer name so that multiple workers sharing a +# consumer group (same handler code → same group) each claim a distinct slice of +# the stream and own their own PEL entries — the competing-consumers pattern. +_CONSUMER_INSTANCE = f"{socket.gethostname()}-{os.getpid()}" + @dataclass(frozen=True) class _StreamGroupInfo: @@ -47,6 +51,8 @@ def __init__( broker: RedisBroker | None = None, url: str = "redis://localhost:6379", group_monitor_interval_s: float = 5.0, + max_len: int | None = None, + retry_max_len: int | None = 10_000, **kwargs, ): """ @@ -56,6 +62,18 @@ def __init__( broker (Optional[RedisBroker]): An existing RedisBroker instance to use. If not provided, a new instance will be created with the specified URL and additional parameters. url (str): The Redis connection URL (default is "redis://localhost:6379"). + group_monitor_interval_s (float): How often the background monitor re-asserts that every + registered consumer group exists (default 5.0s). + max_len (Optional[int]): Approximate cap on the length of streams written by ``publish()`` + (the producer path), applied as ``XADD ... MAXLEN ~ max_len``. Default ``None`` (unbounded). + NOTE: ``MAXLEN`` trims the *oldest* entries by count regardless of whether they have been + consumed/acked, so a value smaller than ``throughput × consumer-lag`` can silently drop + un-delivered messages. Left opt-in for that reason — set it deliberately per your retention + needs (e.g. 100_000) once you understand your traffic. + retry_max_len (Optional[int]): Approximate cap on the SDK-managed retry and DLQ streams + (default 10_000). These hold only reclaimed failures, so their volume is bounded by your + error rate and a default cap prevents a runaway retry loop from growing without bound. + Set to ``None`` to disable trimming on retry/DLQ streams. **kwargs: Additional keyword arguments to pass to the RedisBroker if a new instance is created. Attributes: @@ -114,6 +132,8 @@ def __init__( else: self.broker = RedisBroker(url, log_level=logging.INFO, **kwargs) self._redis_url = url + self._max_len = max_len + self._retry_max_len = retry_max_len self._running = False self._reclaimer_manager: PendingReclaimerManager | None = None # A set so repeated subscribe() calls with an identical (stream_key, group, @@ -195,7 +215,12 @@ async def publish(self, channel: str, message: dict[str, Any] | BaseMessage): before being sent. """ - await self.broker.publish(message, stream=channel) + # When max_len is configured, cap the stream approximately (XADD MAXLEN ~) + # so the producer path can't grow without bound. None → no trimming. + if self._max_len is not None: + await self.broker.publish(message, stream=channel, maxlen=self._max_len) + else: + await self.broker.publish(message, stream=channel) async def subscribe(self, channel: str, handler, **kwargs) -> Callable: """ @@ -207,11 +232,15 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: **kwargs: Additional keyword arguments that can be used to configure the subscription. Keyword Args: - filter_by_message (Callable, optional): A function to filter incoming messages based on their payload. If provided, - this function will be applied to the message payload before passing it to - the handler. - data_type (BaseModel, optional): A Pydantic model class to validate and filter incoming messages by type. - filter_by_data (Callable, optional): A function to filter typed messages after validation (requires `data_type`). + filter_by_message (Callable, optional): Predicate applied to the decoded message dict; the handler + is invoked (with the dict) only for messages where it returns truthy. Non-matching messages + are skipped (acked, not retried). + data_type (BaseModel, optional): A Pydantic model class used to validate and type incoming messages. + Messages that fail validation, or whose ``type`` field does not match the model's default + ``type``, are skipped. Matching messages are passed to the handler as the **typed model + instance** (not the raw dict). + filter_by_data (Callable, optional): Predicate applied to the validated typed message (requires + `data_type`); the handler runs only when it returns truthy. # Redis Pub/Sub parameters pattern (bool, optional): Whether to use pattern-based subscription (default is False). @@ -219,8 +248,12 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: # Redis Stream parameters stream (Optional[str], optional): Redis stream name to consume from instead of Pub/Sub channel. polling_interval (int, optional): Interval in milliseconds for polling streams (default is 100). - group (Optional[str], optional): Consumer group name for stream consumption. - consumer (Optional[str], optional): Consumer name within the group. + group (Optional[str], optional): Consumer group name for stream consumption. Defaults to the + handler_id, which is stable across workers running the same handler — so a fleet of workers + shares one group and load-balances the stream (competing consumers). + consumer (Optional[str], optional): Consumer name within the group. Defaults to a + per-process-unique name (``{handler_id}-{hostname}-{pid}``) so each worker in the group owns + a distinct slice of the PEL. Pass an explicit value only if you need a stable consumer name. batch (bool, optional): Whether to consume messages in batches (default is False). max_records (Optional[int], optional): Maximum number of records to consume in one batch (default is None). last_id (str, optional): Starting message ID for stream consumption (default is ">" for consumer groups). @@ -251,13 +284,13 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: retry_on_error (bool, optional): Whether to retry handler on error (default is True). # Durability parameters - max_len (Optional[int], optional): Maximum stream length to prevent unbounded growth (default is None). - Recommend setting to a reasonable value like 10000 for production. + # + # Stream length is capped at the transport level, not per-subscriber: pass + # ``max_len`` (producer/publish path) and/or ``retry_max_len`` (retry & DLQ + # streams, default 10_000) to ``RedisTransport(...)``. See the constructor docstring. # General parameters dependencies (Sequence[Depends], optional): Custom dependencies for this subscriber. - middlewares (Sequence[BrokerMiddleware], optional): Custom middlewares for this subscriber. - filter (Filter, optional): Message filter configuration. parser (Optional[CustomCallable], optional): Custom parser for this subscriber. decoder (Optional[CustomCallable], optional): Custom decoder for this subscriber. no_reply (bool, optional): Whether to disable message acknowledgment (default is False). @@ -274,52 +307,57 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: on_dlq = kwargs.pop("on_dlq", None) _internal_retry = kwargs.pop("_internal_retry", False) - # The tracing wrapper binds the channel name into each span's - # messaging.destination, so we wrap per-stream: the main subscriber is - # wrapped with `channel` here, and (in the retry block below) the retry - # subscriber is wrapped with the retry stream key — otherwise retry-attempt - # spans would report the original channel and could not be dashboarded - # separately. We keep the unwrapped `original_handler` so the retry wrapper - # wraps it directly rather than nesting a second span on the channel - # wrapper. The recursive retry subscribe is marked _internal_retry=True and + # EggAI applies content filtering (filter_by_message) and typed-subscription + # support (data_type / filter_by_data) by wrapping the handler — see + # wrap_handler_with_filters — NOT via FastStream subscriber middlewares, + # which FastStream 0.7 removed from subscriber()/the broker constructor. + # + # Wrapper order matters: tracing must stay OUTERMOST so traceparent is read + # from the raw decoded dict before a data_type subscription validates it + # into a typed model (the typed model may not carry the traceparent field). + # So the stack is tracing( filters( handler ) ). + # + # The tracing wrapper also binds the channel name into each span's + # messaging.destination, so we wrap per-stream: the main subscriber with + # `channel` here and (in the retry block below) the retry subscriber with the + # retry stream key — otherwise retry-attempt spans would report the original + # channel and could not be dashboarded separately. We keep the unwrapped + # `original_handler` so the retry block can rebuild the same filter+trace + # stack. The recursive retry subscribe is marked _internal_retry=True and # receives an already-wrapped handler, so it must not wrap again. + filter_opts = { + "filter_by_message": kwargs.pop("filter_by_message", None), + "data_type": kwargs.pop("data_type", None), + "filter_by_data": kwargs.pop("filter_by_data", None), + } original_handler = handler if not _internal_retry: from eggai.tracing import make_tracing_wrapper - handler = make_tracing_wrapper(channel, handler) - - if "filter_by_message" in kwargs: - if "middlewares" not in kwargs: - kwargs["middlewares"] = [] - kwargs["middlewares"].append( - create_filter_middleware(kwargs.pop("filter_by_message")) + handler = make_tracing_wrapper( + channel, wrap_handler_with_filters(handler, **filter_opts) ) - if "data_type" in kwargs: - data_type = kwargs.pop("data_type") - - if "middlewares" not in kwargs: - kwargs["middlewares"] = [] - kwargs["middlewares"].append(create_data_type_middleware(data_type)) - - if "filter_by_data" in kwargs: - if "middlewares" not in kwargs: - kwargs["middlewares"] = [] - kwargs["middlewares"].append( - create_filter_by_data_middleware( - data_type, kwargs.pop("filter_by_data") - ) - ) - handler_id = kwargs.pop("handler_id", None) # Ignore Kafka-specific parameter (Redis uses 'group' for streams, not 'group_id') kwargs.pop("group_id", None) - # Extract stream-related parameters + # Extract stream-related parameters. + # `group` defaults to handler_id so that multiple workers running the same + # handler code share one consumer group (Redis distributes new messages + # across the group — competing-consumers load balancing). `consumer` must + # be distinct *per worker* within that group, otherwise two processes share + # one consumer name and their PEL entries become indistinguishable. So when + # the caller doesn't pin a consumer, default it to a per-process-unique name + # derived from handler_id. Direct callers passing consumer= keep full + # control; handler_id=None (no group) leaves consumer None as before. group = kwargs.pop("group", handler_id) - consumer = kwargs.pop("consumer", handler_id) + consumer = kwargs.pop("consumer", None) + if consumer is None: + consumer = ( + f"{handler_id}-{_CONSUMER_INSTANCE}" if handler_id else handler_id + ) polling_interval = kwargs.pop("polling_interval", 100) batch = kwargs.pop("batch", False) max_records = kwargs.pop("max_records", None) @@ -457,14 +495,20 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: ) ) - # Wrap the *original* (unwrapped) handler with the retry stream as - # its tracing destination, so retry-attempt spans report the retry - # stream key rather than the original channel. Wrapping the original - # (not the channel-wrapped handler) keeps it to a single consumer - # span per retry. + # Rebuild the same filter+trace stack against the retry stream: + # tracing( filters( original_handler ) ), with the retry stream as the + # tracing destination so retry-attempt spans report the retry stream + # key rather than the original channel. Reclaimed messages re-enter as + # raw dicts, so they must pass back through wrap_handler_with_filters + # to be typed/filtered exactly like a first delivery. Wrapping the + # original (not the channel-wrapped handler) keeps it to a single + # consumer span per retry. from eggai.tracing import make_tracing_wrapper - retry_handler = make_tracing_wrapper(retry_stream, original_handler) + retry_handler = make_tracing_wrapper( + retry_stream, + wrap_handler_with_filters(original_handler, **filter_opts), + ) # Auto-subscribe the same handler to the retry stream. # _internal_retry=True prevents infinite recursion and .retry.retry @@ -472,13 +516,18 @@ async def subscribe(self, channel: str, handler, **kwargs) -> Callable: # forwarding the caller's last_id: an operator replaying the main # stream (e.g. last_id="0") must not also replay the retry stream's # history — the retry stream only ever holds freshly-reclaimed entries. + # + # `group` is pinned to retry_handler_id (stable across workers) but + # `consumer` is deliberately omitted so it defaults to the + # per-process-unique name — otherwise a worker fleet would conflate + # the retry stream's PEL under one consumer name, undercutting the + # competing-consumers behaviour for retried messages. await self.subscribe( retry_stream, retry_handler, _internal_retry=True, handler_id=retry_handler_id, group=retry_handler_id, - consumer=retry_handler_id, polling_interval=polling_interval, batch=batch, max_records=max_records, @@ -590,6 +639,7 @@ def _setup_reclaimer( max_retries=max_retries, dlq_stream=self._get_stream_key(dlq_stream) if dlq_stream else None, on_dlq=on_dlq, + max_len=self._retry_max_len, ) ) diff --git a/sdk/tests/test_middleware_utils.py b/sdk/tests/test_middleware_utils.py new file mode 100644 index 00000000..78f181d8 --- /dev/null +++ b/sdk/tests/test_middleware_utils.py @@ -0,0 +1,128 @@ +"""Unit tests for the shared handler-filtering wrapper. + +These are intentionally broker-independent so they run in CI even when no Redis or +Kafka service is available (the integration suites are auto-skipped without one). +They lock in the contract that both the Redis and Kafka transports rely on after +FastStream 0.7 removed subscriber middlewares. +""" + +import pytest +from pydantic import BaseModel + +from eggai.schemas import BaseMessage +from eggai.transport.middleware_utils import wrap_handler_with_filters + + +class Order(BaseModel): + order_id: int + status: str + + +class OrderMessage(BaseMessage[Order]): + type: str = "OrderMessage" + + +def _order_msg(order_id=1, status="new", **overrides): + payload = OrderMessage(source="t", data=Order(order_id=order_id, status=status)) + # Brokers hand the handler a decoded dict, so feed the wrapper a dict. + return {**payload.model_dump(mode="json"), **overrides} + + +@pytest.mark.asyncio +async def test_no_options_returns_handler_unchanged(): + async def handler(m): + return m + + assert wrap_handler_with_filters(handler) is handler + + +@pytest.mark.asyncio +async def test_filter_by_message_passes_dict_on_match(): + seen = [] + + async def handler(m): + seen.append(m) + + wrapped = wrap_handler_with_filters( + handler, filter_by_message=lambda m: m["type"] == "keep" + ) + await wrapped({"type": "keep", "v": 1}) + await wrapped({"type": "drop", "v": 2}) + + assert seen == [{"type": "keep", "v": 1}] + + +@pytest.mark.asyncio +async def test_data_type_delivers_typed_instance(): + seen = [] + + async def handler(order): + seen.append(order) + + wrapped = wrap_handler_with_filters(handler, data_type=OrderMessage) + await wrapped(_order_msg(order_id=42, status="new")) + + assert len(seen) == 1 + assert isinstance(seen[0], OrderMessage) # typed, not dict + assert seen[0].data.order_id == 42 + + +@pytest.mark.asyncio +async def test_data_type_skips_wrong_type(): + seen = [] + + async def handler(order): + seen.append(order) + + wrapped = wrap_handler_with_filters(handler, data_type=OrderMessage) + # Right shape, wrong discriminator. + await wrapped(_order_msg(type="SomethingElse")) + + assert seen == [] + + +@pytest.mark.asyncio +async def test_data_type_skips_invalid_payload(): + seen = [] + + async def handler(order): + seen.append(order) + + wrapped = wrap_handler_with_filters(handler, data_type=OrderMessage) + # Correct discriminator, invalid payload → ValidationError → skipped, no raise. + await wrapped({"type": "OrderMessage", "source": "t", "data": {"bad": True}}) + + assert seen == [] + + +def test_data_type_with_filter_by_message_is_rejected(): + """Combining data_type with filter_by_message is a footgun (one would be + silently dropped), so it must raise rather than quietly ignore the predicate.""" + + async def handler(m): + return m + + with pytest.raises(ValueError, match="cannot be combined with data_type"): + wrap_handler_with_filters( + handler, + data_type=OrderMessage, + filter_by_message=lambda m: True, + ) + + +@pytest.mark.asyncio +async def test_filter_by_data_narrows_typed_messages(): + seen = [] + + async def handler(order): + seen.append(order.data.order_id) + + wrapped = wrap_handler_with_filters( + handler, + data_type=OrderMessage, + filter_by_data=lambda o: o.data.status == "shipped", + ) + await wrapped(_order_msg(order_id=1, status="new")) + await wrapped(_order_msg(order_id=2, status="shipped")) + + assert seen == [2] diff --git a/sdk/tests/test_redis.py b/sdk/tests/test_redis.py index 74d46949..12113c29 100644 --- a/sdk/tests/test_redis.py +++ b/sdk/tests/test_redis.py @@ -15,8 +15,10 @@ import pytest import redis.asyncio as redis from faststream import AckPolicy +from pydantic import BaseModel from eggai import Agent, Channel +from eggai.schemas import BaseMessage from eggai.transport import RedisTransport, eggai_set_default_transport @@ -1268,3 +1270,356 @@ async def handler(message): "Message published during partial group loss was not redelivered. " "The group was likely recreated with id='$' instead of id='0'." ) + + +# --- Consumer naming (load balancing) --------------------------------------- + + +@pytest.mark.asyncio +async def test_consumer_defaults_to_unique_per_process(monkeypatch): + """The stream consumer name defaults to a per-process-unique value while the + group stays = handler_id. + + This is the competing-consumers prerequisite: a fleet of workers running the + same handler shares one group (so Redis load-balances the stream) but each + owns a distinct consumer name (so PEL entries are not conflated). + """ + import eggai.transport.redis as redis_mod + + captured = {} + real_stream_sub = redis_mod.StreamSub + + def spy_stream_sub(channel, **kwargs): + captured["group"] = kwargs.get("group") + captured["consumer"] = kwargs.get("consumer") + return real_stream_sub(channel, **kwargs) + + monkeypatch.setattr(redis_mod, "StreamSub", spy_stream_sub) + + transport = RedisTransport() + + async def handler(message): + return message + + await transport.subscribe("orders", handler, handler_id="orders-handler-1") + + assert captured["group"] == "orders-handler-1" + assert captured["consumer"] == f"orders-handler-1-{redis_mod._CONSUMER_INSTANCE}" + # group is stable across workers; consumer carries the per-process suffix. + assert captured["consumer"] != captured["group"] + + +@pytest.mark.asyncio +async def test_retry_subscriber_uses_unique_consumer(monkeypatch): + """The auto-created retry-stream subscriber must also get a per-process-unique + consumer name (group stays stable) — otherwise a worker fleet conflates the + retry stream's PEL under one consumer name, undercutting load balancing for + retried messages.""" + import eggai.transport.redis as redis_mod + + captured = [] + real_stream_sub = redis_mod.StreamSub + + def spy_stream_sub(channel, **kwargs): + captured.append((channel, kwargs.get("group"), kwargs.get("consumer"))) + return real_stream_sub(channel, **kwargs) + + monkeypatch.setattr(redis_mod, "StreamSub", spy_stream_sub) + + transport = RedisTransport() + + async def handler(message): + return message + + await transport.subscribe( + "orders", handler, handler_id="orders-handler-1", retry_on_idle_ms=500 + ) + + retry_handler_id = "orders-handler-1-retry" + retry_subs = [c for c in captured if c[0] == "orders.orders-handler-1.retry"] + assert retry_subs, f"no retry-stream subscriber registered; saw {captured}" + _, group, consumer = retry_subs[0] + assert group == retry_handler_id # group stable across workers + assert consumer == f"{retry_handler_id}-{redis_mod._CONSUMER_INSTANCE}" + assert consumer != retry_handler_id # not the shared/static name + + +@pytest.mark.asyncio +async def test_explicit_consumer_is_respected(monkeypatch): + """An explicit consumer= is passed through unchanged (no per-process suffix).""" + import eggai.transport.redis as redis_mod + + captured = {} + real_stream_sub = redis_mod.StreamSub + + def spy_stream_sub(channel, **kwargs): + captured.update(group=kwargs.get("group"), consumer=kwargs.get("consumer")) + return real_stream_sub(channel, **kwargs) + + monkeypatch.setattr(redis_mod, "StreamSub", spy_stream_sub) + + transport = RedisTransport() + + async def handler(message): + return message + + await transport.subscribe( + "orders", + handler, + handler_id="orders-handler-1", + group="shared-group", + consumer="fixed-consumer", + ) + + assert captured["group"] == "shared-group" + assert captured["consumer"] == "fixed-consumer" + + +# --- max_len / stream length capping ---------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_passes_maxlen_when_configured(monkeypatch): + """publish() forwards maxlen to the broker only when max_len is configured.""" + transport = RedisTransport(max_len=1000) + + calls = [] + + async def fake_publish(message, **kwargs): + calls.append(kwargs) + + monkeypatch.setattr(transport.broker, "publish", fake_publish) + + await transport.publish("orders", {"x": 1}) + + assert calls == [{"stream": "orders", "maxlen": 1000}] + + +@pytest.mark.asyncio +async def test_publish_omits_maxlen_by_default(monkeypatch): + """With the default max_len=None, publish() must not pass maxlen (no silent + trimming of the producer's stream).""" + transport = RedisTransport() # max_len defaults to None + + calls = [] + + async def fake_publish(message, **kwargs): + calls.append(kwargs) + + monkeypatch.setattr(transport.broker, "publish", fake_publish) + + await transport.publish("orders", {"x": 1}) + + assert calls == [{"stream": "orders"}] + assert "maxlen" not in calls[0] + + +@pytest.mark.asyncio +async def test_retry_max_len_propagates_to_reclaimer_configs(): + """retry_max_len caps the retry/DLQ streams; default is 10_000, overridable.""" + transport = RedisTransport() # retry_max_len defaults to 10_000 + + async def handler(message): + return message + + await transport.subscribe( + "orders", handler, handler_id="orders-handler-1", retry_on_idle_ms=500 + ) + + assert transport._reclaimer_manager is not None + configs = list(transport._reclaimer_manager._configs.values()) + assert configs, "expected reclaimer configs to be registered" + assert all(config.max_len == 10_000 for config in configs) + + # Custom value flows through. + transport2 = RedisTransport(retry_max_len=250) + await transport2.subscribe( + "orders", handler, handler_id="orders-handler-2", retry_on_idle_ms=500 + ) + configs2 = list(transport2._reclaimer_manager._configs.values()) + assert all(config.max_len == 250 for config in configs2) + + +@pytest.mark.asyncio +async def test_publish_maxlen_bounds_stream_length(): + """Integration: publishing well past max_len keeps the stream approximately + bounded (XADD MAXLEN ~). Approximate trimming works on whole-node boundaries, + so we publish far more than the cap and assert it stays within a generous + multiple rather than an exact count.""" + redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True) + + test_id = uuid.uuid4().hex[:8] + channel_name = f"test-maxlen-{test_id}" + stream_name = f"eggai.{channel_name}" + + max_len = 100 + transport = RedisTransport(max_len=max_len) + channel = Channel(channel_name, transport=transport) + + # Publish 20x the cap so trimming definitely crosses node boundaries. + for i in range(max_len * 20): + await channel.publish({"type": "test", "n": i, "test_id": test_id}) + + length = await redis_client.xlen(stream_name) + + # Approximate trimming never trims below the cap and only slightly above it; + # without trimming the length would be 2000. Assert it is meaningfully bounded. + await redis_client.delete(stream_name) + await redis_client.aclose() + await transport.disconnect() + + assert length < max_len * 5, ( + f"Stream length {length} not bounded by max_len~{max_len} " + f"(expected well under {max_len * 5}, unbounded would be {max_len * 20})" + ) + + +# --- Filtering / typed subscriptions (FastStream 0.7 middleware removal) ------ +# +# These prove that filter_by_message / data_type / filter_by_data work on the +# Redis transport without FastStream subscriber middlewares (removed in 0.7), and +# that data_type delivers the *typed model instance* to the handler — the same +# contract the InMemory transport and test_typed_subscribe.py already rely on. + + +class _Order(BaseModel): + order_id: int + status: str + + +class _OrderMessage(BaseMessage[_Order]): + type: str = "_OrderMessage" + + +class _PaymentMessage(BaseMessage[_Order]): + type: str = "_PaymentMessage" + + +@pytest.mark.asyncio +async def test_redis_filter_by_message_routes_dicts(): + """filter_by_message invokes the handler (with the dict) only for matches.""" + test_id = uuid.uuid4().hex[:8] + transport = RedisTransport() + agent = Agent(f"filter-agent-{test_id}", transport=transport) + channel = Channel(f"test-filter-msg-{test_id}", transport=transport) + + seen = [] + got = asyncio.Event() + + @agent.subscribe( + channel=channel, + filter_by_message=lambda m: m.get("type") == "wanted", + ) + async def handler(message): + seen.append(message["type"]) + got.set() + + await agent.start() + await channel.publish({"type": "unwanted", "v": 1}) + await channel.publish({"type": "wanted", "v": 2}) + + await asyncio.wait_for(got.wait(), timeout=5.0) + await asyncio.sleep(0.3) # give the unwanted message a chance to (not) arrive + await agent.stop() + + assert seen == ["wanted"], f"filter_by_message let through: {seen}" + + +@pytest.mark.asyncio +async def test_redis_data_type_delivers_typed_instance(): + """data_type passes the validated *typed model* (not a dict) to the handler, + and skips messages whose type doesn't match.""" + test_id = uuid.uuid4().hex[:8] + transport = RedisTransport() + agent = Agent(f"typed-agent-{test_id}", transport=transport) + channel = Channel(f"test-typed-{test_id}", transport=transport) + + received = [] + got = asyncio.Event() + + @agent.subscribe(channel=channel, data_type=_OrderMessage) + async def handle_order(order: _OrderMessage): + received.append(order) + got.set() + + await agent.start() + + # Wrong type → skipped. + await channel.publish( + _PaymentMessage(source="t", data=_Order(order_id=1, status="x")) + ) + # Right type → delivered as a typed instance. + await channel.publish( + _OrderMessage(source="t", data=_Order(order_id=42, status="new")) + ) + + await asyncio.wait_for(got.wait(), timeout=5.0) + await asyncio.sleep(0.3) + await agent.stop() + + assert len(received) == 1, f"expected 1 matching message, got {len(received)}" + order = received[0] + # The headline contract: a typed model instance, not a dict. + assert isinstance(order, _OrderMessage), f"handler got {type(order)!r}, not typed" + assert isinstance(order.data, _Order) + assert order.data.order_id == 42 + assert order.data.status == "new" + + +@pytest.mark.asyncio +async def test_redis_filter_by_data_on_typed_message(): + """filter_by_data narrows typed messages by their validated content.""" + test_id = uuid.uuid4().hex[:8] + transport = RedisTransport() + agent = Agent(f"fbd-agent-{test_id}", transport=transport) + channel = Channel(f"test-fbd-{test_id}", transport=transport) + + received = [] + got = asyncio.Event() + + @agent.subscribe( + channel=channel, + data_type=_OrderMessage, + filter_by_data=lambda o: o.data.status == "shipped", + ) + async def handle_shipped(order: _OrderMessage): + received.append(order.data.order_id) + got.set() + + await agent.start() + await channel.publish( + _OrderMessage(source="t", data=_Order(order_id=1, status="new")) + ) + await channel.publish( + _OrderMessage(source="t", data=_Order(order_id=2, status="shipped")) + ) + + await asyncio.wait_for(got.wait(), timeout=5.0) + await asyncio.sleep(0.3) + await agent.stop() + + assert received == [2], f"filter_by_data let through: {received}" + + +@pytest.mark.asyncio +async def test_redis_invalid_payload_is_skipped_not_retried(): + """A message of the right type but invalid payload is skipped silently + (validation fails → no handler call, no error/retry).""" + test_id = uuid.uuid4().hex[:8] + transport = RedisTransport() + agent = Agent(f"invalid-agent-{test_id}", transport=transport) + channel = Channel(f"test-invalid-{test_id}", transport=transport) + + calls = [] + + @agent.subscribe(channel=channel, data_type=_OrderMessage) + async def handler(order: _OrderMessage): + calls.append(order) + + await agent.start() + # Correct type, but data is missing required fields → ValidationError → skip. + await channel.publish({"type": "_OrderMessage", "source": "t", "data": {"bad": 1}}) + await asyncio.sleep(0.6) + await agent.stop() + + assert calls == [], "invalid-payload message should not reach the handler" From f20b00d5342bda42f7ab3b1a85900db7f7ea408f Mon Sep 17 00:00:00 2001 From: Rocky Jaiswal Date: Thu, 4 Jun 2026 15:03:59 +0200 Subject: [PATCH 2/2] =?UTF-8?q?fix(transport):=20address=20review=20?= =?UTF-8?q?=E2=80=94=20filter=20validation,=20sync=20handlers,=20in-memory?= =?UTF-8?q?=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sdk/CHANGELOG.md | 8 +- sdk/eggai/transport/inmemory.py | 72 +++---------- sdk/eggai/transport/middleware_utils.py | 56 ++++++++-- sdk/tests/test_inmemory_filters.py | 136 ++++++++++++++++++++++++ sdk/tests/test_middleware_utils.py | 70 ++++++++++++ 5 files changed, 273 insertions(+), 69 deletions(-) create mode 100644 sdk/tests/test_inmemory_filters.py diff --git a/sdk/CHANGELOG.md b/sdk/CHANGELOG.md index dce1ac19..cfecf030 100644 --- a/sdk/CHANGELOG.md +++ b/sdk/CHANGELOG.md @@ -18,9 +18,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 the same approach the in-memory transport already uses). As part of this, `data_type` subscriptions on Redis/Kafka now deliver the **typed model instance** to the handler (matching the in-memory transport and the documented behaviour), - rather than the raw dict. Supplying both `data_type` and `filter_by_message` - (mutually exclusive — use `filter_by_data` for typed filtering) now raises - `ValueError` instead of silently dropping the raw-dict predicate. + rather than the raw dict. Invalid filter-option combinations now raise + `ValueError` instead of silently dropping an option: `data_type` and + `filter_by_message` are mutually exclusive, and `filter_by_data` requires + `data_type`. These validations are consistent across the Redis, Kafka, and + in-memory transports. ### Added - **RedisTransport**: New `max_len` and `retry_max_len` constructor options to cap diff --git a/sdk/eggai/transport/inmemory.py b/sdk/eggai/transport/inmemory.py index c70729ea..5d8e9ca4 100644 --- a/sdk/eggai/transport/inmemory.py +++ b/sdk/eggai/transport/inmemory.py @@ -8,6 +8,7 @@ from eggai.schemas import BaseMessage from eggai.transport import Transport +from eggai.transport.middleware_utils import wrap_handler_with_filters logger = logging.getLogger(__name__) @@ -122,66 +123,17 @@ async def subscribe( group_id = kwargs.get("group_id", handler_id or uuid.uuid4().hex) key = (channel, group_id) - final_callback = callback - - # filter_by_message (raw-dict filter) and data_type (typed validation + - # filter_by_data) are mutually exclusive — reject the combination rather - # than silently ignoring one, matching the Kafka/Redis transports. - if "data_type" in kwargs and kwargs.get("filter_by_message") is not None: - raise ValueError( - "filter_by_message cannot be combined with data_type. Use " - "filter_by_data to filter typed subscriptions, or filter_by_message " - "on its own for raw-dict filtering." - ) - - # Handle data_type filtering - if "data_type" in kwargs: - data_type = kwargs["data_type"] - - async def data_type_filtered_callback(data): - try: - typed_message = data_type.model_validate(data) - # Check if message type matches expected type - if typed_message.type != data_type.model_fields["type"].default: - return - # Pass the typed message object to the handler - await callback(typed_message) - except Exception: - # Skip messages that don't match the data type - return - - final_callback = data_type_filtered_callback - - # Handle filter_by_data if present along with data_type - if "filter_by_data" in kwargs: - filter_func = kwargs["filter_by_data"] - - async def data_and_filter_callback(data): - try: - typed_message = data_type.model_validate(data) - # Check if message type matches expected type - if typed_message.type != data_type.model_fields["type"].default: - return - # Apply the data filter - if filter_func(typed_message): - await callback(typed_message) - except (json.JSONDecodeError, ValueError, TypeError) as e: - # Skip messages that don't match the data type or filter - logger.debug(f"Message validation failed: {e}") - return - - final_callback = data_and_filter_callback - - # Handle legacy filter_by_message (for backward compatibility) - elif "filter_by_message" in kwargs: - filter_func = kwargs["filter_by_message"] - original_callback = final_callback # Store original before reassignment - - async def filtered_callback(data): - if filter_func(data): - await original_callback(data) # Use original, not final_callback - - final_callback = filtered_callback + # Content filtering / typed-subscription handling is shared with the Redis + # and Kafka transports (the consume loop below already hands callbacks a + # decoded dict, the same input the wrapper expects). This keeps validation + # and typed delivery identical across all transports and is the single place + # that rejects invalid filter-option combinations. Tracing stays outermost. + final_callback = wrap_handler_with_filters( + callback, + data_type=kwargs.get("data_type"), + filter_by_data=kwargs.get("filter_by_data"), + filter_by_message=kwargs.get("filter_by_message"), + ) from eggai.tracing import make_tracing_wrapper diff --git a/sdk/eggai/transport/middleware_utils.py b/sdk/eggai/transport/middleware_utils.py index ae44e572..f565abff 100644 --- a/sdk/eggai/transport/middleware_utils.py +++ b/sdk/eggai/transport/middleware_utils.py @@ -10,11 +10,44 @@ the behaviour identical across the Kafka, Redis, and in-memory transports. """ +import inspect from collections.abc import Callable from typing import Any from pydantic import ValidationError +# Identity attributes copied from the user handler onto a wrapper. We intentionally +# do NOT use functools.wraps here: it sets __wrapped__, which inspect.signature +# follows — FastStream would then introspect the *original* handler's signature +# (e.g. ``order: OrderMessage``) and try to decode the message into that type +# itself, before our wrapper runs. Our wrappers must keep their own ``(message)`` +# signature so FastStream hands them the raw dict to validate/filter. Copying just +# these attributes preserves the handler's name/docs for logging and AsyncAPI +# without changing what FastStream decodes. +_IDENTITY_ATTRS = ("__module__", "__name__", "__qualname__", "__doc__") + + +def _carry_identity(wrapper: Callable, handler: Callable) -> Callable: + for attr in _IDENTITY_ATTRS: + try: + setattr(wrapper, attr, getattr(handler, attr)) + except AttributeError: + pass + return wrapper + + +async def _invoke(handler: Callable, arg: Any) -> Any: + """Call ``handler`` with ``arg``, awaiting the result only if it is awaitable. + + Mirrors ``make_tracing_wrapper``'s tolerance of synchronous handlers: a sync + handler combined with a filter option must not raise ``TypeError`` from an + unconditional ``await``. + """ + result = handler(arg) + if inspect.isawaitable(result): + return await result + return result + def wrap_handler_with_filters( handler: Callable, @@ -42,8 +75,9 @@ def wrap_handler_with_filters( ``filter_by_message`` and ``data_type`` are mutually exclusive: the former is the untyped (raw-dict) filter, the latter validates into a typed model and - pairs with ``filter_by_data``. Supplying both is rejected rather than silently - dropping one of them. + pairs with ``filter_by_data``. ``filter_by_data`` requires ``data_type``. + Invalid combinations raise ``ValueError`` rather than silently dropping an + option. """ if data_type is not None and filter_by_message is not None: raise ValueError( @@ -51,8 +85,18 @@ def wrap_handler_with_filters( "(which receives the validated typed message) to filter typed " "subscriptions, or filter_by_message on its own for raw-dict filtering." ) + if filter_by_data is not None and data_type is None: + raise ValueError( + "filter_by_data requires data_type — it receives the validated typed " + "message. Use filter_by_message to filter on the raw dict instead." + ) if data_type is not None: + if "type" not in data_type.model_fields: + raise ValueError( + f"data_type {data_type.__name__!r} must define a 'type' field " + "(the discriminator used to match messages, as on BaseMessage)." + ) expected_type = data_type.model_fields["type"].default async def typed_handler(message: dict[str, Any]) -> Any: @@ -65,17 +109,17 @@ async def typed_handler(message: dict[str, Any]) -> Any: return None if filter_by_data is not None and not filter_by_data(typed_message): return None - return await handler(typed_message) + return await _invoke(handler, typed_message) - return typed_handler + return _carry_identity(typed_handler, handler) if filter_by_message is not None: async def filtered_handler(message: dict[str, Any]) -> Any: if filter_by_message(message): - return await handler(message) + return await _invoke(handler, message) return None - return filtered_handler + return _carry_identity(filtered_handler, handler) return handler diff --git a/sdk/tests/test_inmemory_filters.py b/sdk/tests/test_inmemory_filters.py new file mode 100644 index 00000000..4006a216 --- /dev/null +++ b/sdk/tests/test_inmemory_filters.py @@ -0,0 +1,136 @@ +"""Broker-independent tests for InMemoryTransport filtering / typed subscriptions. + +The existing `test_inmemory_transport.py` is a runnable demo (no test functions), +and the filter/data_type scenarios elsewhere use KafkaTransport — so they are +skipped when no broker is available. That left InMemory's filtering untested, +which is how its data_type-presence bugs went unnoticed. These tests run with no +broker and assert InMemory shares the same filtering contract as Redis/Kafka +(both now go through wrap_handler_with_filters). +""" + +import asyncio +import uuid + +import pytest +from pydantic import BaseModel + +from eggai import Agent, Channel +from eggai.schemas import BaseMessage +from eggai.transport import InMemoryTransport + + +class Order(BaseModel): + order_id: int + status: str + + +class OrderMessage(BaseMessage[Order]): + type: str = "OrderMessage" + + +class PaymentMessage(BaseMessage[Order]): + type: str = "PaymentMessage" + + +async def _collect(subscribe_kwargs, messages, *, settle=0.1): + """Subscribe a handler on a fresh in-memory channel, publish messages, and + return what the handler received.""" + transport = InMemoryTransport() + agent = Agent(f"agent-{uuid.uuid4().hex[:8]}", transport=transport) + channel = Channel(f"chan-{uuid.uuid4().hex[:8]}", transport=transport) + + received = [] + done = asyncio.Event() + + @agent.subscribe(channel=channel, **subscribe_kwargs) + async def handler(message): + received.append(message) + done.set() + + await agent.start() + for msg in messages: + await channel.publish(msg) + await asyncio.sleep(settle) + await agent.stop() + return received + + +@pytest.mark.asyncio +async def test_inmemory_filter_by_message_routes_dicts(): + received = await _collect( + {"filter_by_message": lambda m: m.get("type") == "keep"}, + [{"type": "drop", "v": 1}, {"type": "keep", "v": 2}], + ) + assert [m["type"] for m in received] == ["keep"] + + +@pytest.mark.asyncio +async def test_inmemory_data_type_delivers_typed_instance(): + received = await _collect( + {"data_type": OrderMessage}, + [ + PaymentMessage(source="t", data=Order(order_id=1, status="x")), + OrderMessage(source="t", data=Order(order_id=42, status="new")), + ], + ) + assert len(received) == 1 + assert isinstance(received[0], OrderMessage) # typed, not dict + assert received[0].data.order_id == 42 + + +@pytest.mark.asyncio +async def test_inmemory_filter_by_data_narrows_typed(): + received = await _collect( + { + "data_type": OrderMessage, + "filter_by_data": lambda o: o.data.status == "shipped", + }, + [ + OrderMessage(source="t", data=Order(order_id=1, status="new")), + OrderMessage(source="t", data=Order(order_id=2, status="shipped")), + ], + ) + assert [o.data.order_id for o in received] == [2] + + +@pytest.mark.asyncio +async def test_inmemory_data_type_none_with_filter_by_message_filters_raw(): + """P2: data_type=None must behave like 'no data_type' — raw filtering works, + not a raised error or a skip-everything typed branch.""" + received = await _collect( + {"data_type": None, "filter_by_message": lambda m: m.get("keep") is True}, + [{"keep": False}, {"keep": True}], + ) + assert received == [{"keep": True}] + + +@pytest.mark.asyncio +async def test_inmemory_data_type_none_alone_delivers_all(): + """P2: data_type=None alone must not enter the typed branch and silently skip + every message.""" + received = await _collect({"data_type": None}, [{"a": 1}, {"a": 2}]) + assert received == [{"a": 1}, {"a": 2}] + + +@pytest.mark.asyncio +async def test_inmemory_rejects_invalid_filter_combinations(): + transport = InMemoryTransport() + + async def handler(m): + return m + + with pytest.raises(ValueError, match="cannot be combined with data_type"): + await transport.subscribe( + "c", handler, data_type=OrderMessage, filter_by_message=lambda m: True + ) + + with pytest.raises(ValueError, match="filter_by_data requires data_type"): + await transport.subscribe("c", handler, filter_by_data=lambda o: True) + + # P3: the missing-'type'-field validation now applies to InMemory too (it + # goes through the same shared wrapper), instead of a delivery-time KeyError. + class NoType(BaseModel): + value: int + + with pytest.raises(ValueError, match="must define a 'type' field"): + await transport.subscribe("c", handler, data_type=NoType) diff --git a/sdk/tests/test_middleware_utils.py b/sdk/tests/test_middleware_utils.py index 78f181d8..3c946fc4 100644 --- a/sdk/tests/test_middleware_utils.py +++ b/sdk/tests/test_middleware_utils.py @@ -95,6 +95,76 @@ async def handler(order): assert seen == [] +def test_filter_by_data_without_data_type_is_rejected(): + """filter_by_data operates on the typed message, so it requires data_type; + supplying it alone must raise rather than be silently dropped.""" + + async def handler(m): + return m + + with pytest.raises(ValueError, match="filter_by_data requires data_type"): + wrap_handler_with_filters(handler, filter_by_data=lambda o: True) + + +def test_data_type_without_type_field_is_rejected(): + """A data_type model with no 'type' discriminator gets a clear error, not a + bare KeyError from model_fields['type'].""" + + class NoType(BaseModel): + value: int + + async def handler(m): + return m + + with pytest.raises(ValueError, match="must define a 'type' field"): + wrap_handler_with_filters(handler, data_type=NoType) + + +@pytest.mark.asyncio +async def test_sync_handler_works_with_filter_by_message(): + """A synchronous handler combined with a filter must not raise from an + unconditional await (regression vs the old middleware path).""" + seen = [] + + def sync_handler(m): # note: not async + seen.append(m) + + wrapped = wrap_handler_with_filters(sync_handler, filter_by_message=lambda m: True) + await wrapped({"type": "x"}) + + assert seen == [{"type": "x"}] + + +@pytest.mark.asyncio +async def test_sync_handler_works_with_data_type(): + """Sync handler in the typed path is invoked without awaiting a non-awaitable.""" + seen = [] + + def sync_handler(order): # not async + seen.append(order.data.order_id) + + wrapped = wrap_handler_with_filters(sync_handler, data_type=OrderMessage) + await wrapped(_order_msg(order_id=7)) + + assert seen == [7] + + +def test_wrapper_preserves_handler_identity(): + """The wrapper keeps the original handler's name/doc (for logging/AsyncAPI) + without exposing its signature to FastStream's decoder.""" + import inspect + + async def handle_order(order): + """Handle an order.""" + + wrapped = wrap_handler_with_filters(handle_order, data_type=OrderMessage) + + assert wrapped.__name__ == "handle_order" + assert wrapped.__doc__ == "Handle an order." + # Signature stays (message) so FastStream still delivers the raw dict. + assert list(inspect.signature(wrapped).parameters) == ["message"] + + def test_data_type_with_filter_by_message_is_rejected(): """Combining data_type with filter_by_message is a footgun (one would be silently dropped), so it must raise rather than quietly ignore the predicate."""