diff --git a/.github/workflows/lint-test.yaml b/.github/workflows/lint-test.yaml index 48297d09..ea9291bc 100644 --- a/.github/workflows/lint-test.yaml +++ b/.github/workflows/lint-test.yaml @@ -58,7 +58,7 @@ jobs: if: always() - name: Code complexity - run: uv run xenon --max-absolute B --max-modules A --max-average A plugboard/ + run: uv run xenon --max-absolute C --max-modules A --max-average A plugboard/ if: always() - name: Notebook output cleared diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py index 98eb886c..e7b6bf93 100644 --- a/plugboard-schemas/plugboard_schemas/_validation.py +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -18,6 +18,9 @@ from ._validator_registry import validator +_SYSTEM_STOP_EVENT = "system_stop" + + def _build_component_graph( connectors: dict[str, dict[str, _t.Any]], ) -> dict[str, set[str]]: @@ -100,6 +103,9 @@ def validate_all_inputs_connected( all_inputs = set(io.get("inputs", [])) connected = connected_inputs.get(comp_name, set()) unconnected = all_inputs - connected + if unconnected: + event_covered_fields = set().union(*io.get("event_field_coverage", {}).values()) + unconnected -= event_covered_fields if unconnected: errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}") return errors diff --git a/plugboard/component/component.py b/plugboard/component/component.py index 6fe0ad20..933170e9 100644 --- a/plugboard/component/component.py +++ b/plugboard/component/component.py @@ -95,6 +95,7 @@ def __init__( initial_values=self._initial_values, input_events=self.__class__.io.input_events, output_events=self.__class__.io.output_events, + event_field_coverage=self.__class__.io.event_field_coverage, namespace=self.name, component=self, ) @@ -143,10 +144,9 @@ def parameters(self) -> dict[str, _t.Any]: return self._parameters @classmethod - def _configure_io(cls) -> None: - # Get all parent classes that are Component subclasses + def _get_aggregated_io_args(cls) -> tuple[dict[str, set], list[str]]: + """Get combined set of all io arguments and exports from this class and all parents.""" parent_comps = cls._get_component_bases() - # Create combined set of all io arguments from this class and all parents io_args: dict[str, set] = defaultdict(set) exports: list[str] = [] for c in parent_comps + [cls]: @@ -157,12 +157,30 @@ def _configure_io(cls) -> None: io_args["output_events"].update(c_io.output_events) if c_exports := getattr(c, "exports"): exports.extend(c_exports) + return io_args, exports + + @classmethod + def _get_event_field_coverage(cls) -> dict[str, list[str]]: + """Get event field coverage from all handlers in this class and all parents.""" + event_field_coverage = {} + for attr_name in dir(cls): + attr = getattr(cls, attr_name, None) + if callable(attr) and hasattr(attr, "_event_field_coverage"): + event_field_coverage.update(attr._event_field_coverage) + return event_field_coverage + + @classmethod + def _configure_io(cls) -> None: + # Get all parent classes that are Component subclasses + io_args, exports = cls._get_aggregated_io_args() + event_field_coverage = cls._get_event_field_coverage() # Set io arguments for subclass cls.io = IO( inputs=sorted(io_args["inputs"], key=str), outputs=sorted(io_args["outputs"], key=str), input_events=sorted(io_args["input_events"], key=str), output_events=sorted(io_args["output_events"], key=str), + event_field_coverage=event_field_coverage, ) # Set exports for subclass cls.exports = sorted(set(exports)) @@ -356,7 +374,7 @@ async def _wrapper() -> None: raise e self._bind_outputs() await self.io.write() - self._field_inputs_ready = False + self._reset_input_trackers() await self._set_status(Status.WAITING, publish=not self._is_running) return _wrapper @@ -365,6 +383,11 @@ async def _wrapper() -> None: def _has_field_inputs(self) -> bool: return len(self.io.inputs) > 0 + @property + def _has_connected_field_inputs(self) -> bool: + """Whether any declared field inputs are connected via input channels.""" + return self.io.has_connected_field_inputs + @cached_property def _has_event_inputs(self) -> bool: input_events = set([evt.safe_type() for evt in self.io.input_events]) @@ -409,7 +432,7 @@ async def _io_read_with_status_check(self) -> None: task.cancel() for task in done: exc = task.exception() - if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0: + if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs: await self.io.close() # Call close for final wait and flush event buffer elif exc is not None: raise exc @@ -422,7 +445,7 @@ async def _periodic_status_check(self) -> None: # TODO : Eventually producer graph update will be event driven. For now, # : the update is performed periodically, so it's called here along # : with the status check. - if len(self.io.inputs) == 0: + if not self._has_connected_field_inputs: await self._update_producer_graph() async def _status_check(self) -> None: @@ -455,8 +478,11 @@ def _bind_inputs(self) -> None: for field in self.io.inputs: field_default = getattr(self, field, None) value = self._field_inputs.get(field, field_default) - setattr(self, field, value) + super().__setattr__(field, value) + + def _reset_input_trackers(self) -> None: self._field_inputs = {} + self._field_inputs_ready = False def _bind_outputs(self) -> None: """Binds component fields to output fields.""" diff --git a/plugboard/component/io_controller.py b/plugboard/component/io_controller.py index 7500aee2..9fa6076a 100644 --- a/plugboard/component/io_controller.py +++ b/plugboard/component/io_controller.py @@ -38,6 +38,7 @@ def __init__( initial_values: _t.Optional[dict[str, _t.Iterable]] = None, input_events: _t.Optional[list[_t.Type[Event]]] = None, output_events: _t.Optional[list[_t.Type[Event]]] = None, + event_field_coverage: _t.Optional[dict[str, list[str]]] = None, namespace: str = IO_NS_UNSET, component: _t.Optional[Component] = None, ) -> None: @@ -47,10 +48,27 @@ def __init__( self.initial_values = initial_values or {} self.input_events = input_events or [] self.output_events = output_events or [] + self.event_field_coverage = event_field_coverage or {} if set(self.initial_values.keys()) - set(self.inputs): raise ValueError("Initial values must be for input fields only.") + self._component = component + self._initial_values = {k: deque(v) for k, v in self.initial_values.items()} + self._input_event_types = {Event.safe_type(evt.type) for evt in self.input_events} + self._output_event_types = {Event.safe_type(evt.type) for evt in self.output_events} + + self._logger = DI.logger.resolve_sync().bind( + cls=self.__class__.__name__, namespace=self.namespace + ) + self._logger.info("IOController created") + + # Initialise channel stores + self._input_channels: dict[tuple[str, str], Channel] = {} + self._output_channels: dict[tuple[str, str], Channel] = {} + self._input_event_channels: dict[str, Channel] = {} + self._output_event_channels: dict[str, Channel] = {} + # Initialise buffers self.buf_fields: dict[str, IOBuffer] = { _io_key_in: IOFieldBuffer(), _io_key_out: IOFieldBuffer(), @@ -60,21 +78,9 @@ def __init__( _io_key_out: IOEventBuffer(), } - self._input_channels: dict[tuple[str, str], Channel] = {} - self._output_channels: dict[tuple[str, str], Channel] = {} - self._input_event_channels: dict[str, Channel] = {} - self._output_event_channels: dict[str, Channel] = {} - self._input_event_types = {Event.safe_type(evt.type) for evt in self.input_events} - self._output_event_types = {Event.safe_type(evt.type) for evt in self.output_events} - self._initial_values = {k: deque(v) for k, v in self.initial_values.items()} - self._read_tasks: dict[str | _t_field_key, asyncio.Task] = {} + # Initialise orchestration state self._is_closed = False - - self._logger = DI.logger.resolve_sync().bind( - cls=self.__class__.__name__, namespace=self.namespace - ) - self._logger.info("IOController created") - + self._read_tasks: dict[str | _t_field_key, asyncio.Task] = {} self._received_fields: dict[str, _t.Any] = {} self._received_fields_lock = asyncio.Lock() self._received_events: deque[Event] = deque() @@ -86,8 +92,9 @@ def is_closed(self) -> bool: """Returns `True` if the `IOController` is closed, `False` otherwise.""" return self._is_closed - @cached_property - def _has_field_inputs(self) -> bool: + @property + def has_connected_field_inputs(self) -> bool: + """Returns whether any field inputs are connected via channels.""" return len(self._input_channels) > 0 @cached_property @@ -96,7 +103,7 @@ def _has_event_inputs(self) -> bool: @cached_property def _has_inputs(self) -> bool: - return self._has_field_inputs or self._has_event_inputs + return self.has_connected_field_inputs or self._has_event_inputs async def read(self, timeout: float | None = None) -> None: """Reads data and/or events from input channels. @@ -139,7 +146,7 @@ async def read(self, timeout: float | None = None) -> None: def _set_read_tasks(self) -> list[asyncio.Task]: read_tasks: list[asyncio.Task] = [] - if self._has_field_inputs: + if self.has_connected_field_inputs: if _fields_read_task not in self._read_tasks: read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task) self._read_tasks[_fields_read_task] = read_fields_task @@ -374,7 +381,7 @@ def _add_channel_for_event( def _create_input_field_group_tasks(self) -> None: """Groups input field channels by field name and launches read tasks for group inputs.""" - if not self._has_field_inputs: + if not self.has_connected_field_inputs: return field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list) for key, chan in self._input_channels.items(): @@ -410,6 +417,7 @@ def dict(self) -> dict[str, _t.Any]: # noqa: D102 "input_events": [e.safe_type() for e in self.input_events], "output_events": [e.safe_type() for e in self.output_events], "initial_values": {k: list(v) for k, v in self._initial_values.items()}, + "event_field_coverage": {k: list(v) for k, v in self.event_field_coverage.items()}, } diff --git a/plugboard/events/event.py b/plugboard/events/event.py index 2becb56d..858e26a5 100644 --- a/plugboard/events/event.py +++ b/plugboard/events/event.py @@ -75,9 +75,28 @@ def safe_type(cls, event_type: _t.Optional[str] = None) -> str: """Returns a safe event type string for use in broker topic strings.""" return (event_type or cls.type).replace(".", "_").replace("-", "_") + @_t.overload @classmethod - def handler(cls, method: AsyncCallable) -> AsyncCallable: + def handler(cls, method: AsyncCallable) -> AsyncCallable: ... + + @_t.overload + @classmethod + def handler( + cls, *, populates_fields: _t.Optional[list[str]] = None + ) -> _t.Callable[[AsyncCallable], AsyncCallable]: ... + + @classmethod + def handler( + cls, + method: _t.Optional[AsyncCallable] = None, + *, + populates_fields: _t.Optional[list[str]] = None, + ) -> _t.Union[AsyncCallable, _t.Callable[[AsyncCallable], AsyncCallable]]: """Registers a class method as an event handler.""" + if method is None: + # Invoked as @Event.handler(populates_fields=[...]) + return EventHandlers.add(cls, populates_fields=populates_fields) + # Invoked as @Event.handler return EventHandlers.add(cls)(method) diff --git a/plugboard/events/event_handlers.py b/plugboard/events/event_handlers.py index 344522ce..a92003d9 100644 --- a/plugboard/events/event_handlers.py +++ b/plugboard/events/event_handlers.py @@ -18,11 +18,16 @@ class EventHandlers: # pragma: no cover _handlers: _t.ClassVar[dict[str, dict[str, AsyncCallable]]] = defaultdict(dict) @classmethod - def add(cls, event: _t.Type[Event] | Event) -> _t.Callable[[AsyncCallable], AsyncCallable]: + def add( + cls, + event: _t.Type[Event] | Event, + populates_fields: _t.Optional[list[str]] = None, + ) -> _t.Callable[[AsyncCallable], AsyncCallable]: """Decorator that registers class methods as handlers for specific event types. Args: event: Event class this handler processes + populates_fields: Optional list of fields that the handler populates Returns: Callable: Decorated method @@ -31,6 +36,12 @@ def add(cls, event: _t.Type[Event] | Event) -> _t.Callable[[AsyncCallable], Asyn def decorator(method: AsyncCallable) -> AsyncCallable: class_path = cls._get_class_path_for_method(method) cls._handlers[class_path][event.type] = method + + if populates_fields is not None: + if not hasattr(method, "_event_field_coverage"): + setattr(method, "_event_field_coverage", {}) + getattr(method, "_event_field_coverage")[event.type] = populates_fields + return method return decorator @@ -57,10 +68,11 @@ def get(cls, _class: _t.Type, event: _t.Type[Event] | Event) -> AsyncCallable: Raises: KeyError: If no handler found for class or event type """ + store = cls._handlers for base_class in _class.__mro__: base_path = f"{base_class.__module__}.{base_class.__name__}" - if base_path in cls._handlers and event.type in cls._handlers[base_path]: - return cls._handlers[base_path][event.type] + if base_path in store and event.type in store[base_path]: + return store[base_path][event.type] raise KeyError( f"No handler found for class '{_class.__name__}' and event type '{event.type}'" ) diff --git a/plugboard/library/data_writer.py b/plugboard/library/data_writer.py index 96d14538..598b39b3 100644 --- a/plugboard/library/data_writer.py +++ b/plugboard/library/data_writer.py @@ -50,6 +50,7 @@ def __init__( outputs=None, input_events=self.__class__.io.input_events, output_events=self.__class__.io.output_events, + event_field_coverage=self.__class__.io.event_field_coverage, namespace=self.name, component=self, ) @@ -76,18 +77,39 @@ async def _convert(self, data: dict[str, deque]) -> _t.Any: def _bind_inputs(self) -> None: """Binds input fields to component fields and append to internal buffer.""" super()._bind_inputs() - for field in self.io.inputs: + for field in self._field_inputs: value = getattr(self, field, None) self._buffer[field].append(value) + @property + def _completed_rows(self) -> int: + """Calculates how many fully formed rows exist in the buffer.""" + if not self.io.inputs: + return 0 + return min((len(self._buffer[f]) for f in self.io.inputs), default=0) + + @property + def _can_step(self) -> bool: + """We can step if we have at least one fully formed row.""" + return self._completed_rows > 0 + async def _save_chunk(self) -> None: - """Write data from the buffer.""" + """Write completed data rows from the buffer.""" + completed_rows = self._completed_rows + if completed_rows == 0: + return + if self._task is not None: await self._task - # Create task to save next chunk of data - chunk = await self._convert(self._buffer) + + # Extract only the completed rows into a new chunk + chunk_data = { + field: deque([self._buffer[field].popleft() for _ in range(completed_rows)]) + for field in self.io.inputs + } + + chunk = await self._convert(chunk_data) self._task = asyncio.create_task(self._save(chunk)) - self._buffer = defaultdict(deque) async def step(self) -> None: """Trigger save when buffer is at target size.""" diff --git a/tests/integration/test_component_event_handlers.py b/tests/integration/test_component_event_handlers.py index 8a27ce34..1ba77fe5 100644 --- a/tests/integration/test_component_event_handlers.py +++ b/tests/integration/test_component_event_handlers.py @@ -11,6 +11,7 @@ from plugboard.component import Component, IOController from plugboard.connector import AsyncioConnector, Connector, ConnectorBuilder from plugboard.events import Event +from plugboard.events.event import StopEvent from plugboard.schemas import ConnectorSpec from tests.conftest import zmq_connector_cls @@ -272,3 +273,106 @@ class _A(A): assert getattr(a, "in_2", None) == 6 await a.io.close() + + +class B(Component): + """B test component.""" + + io = IOController(inputs=["a", "b"], input_events=[EventTypeA, EventTypeB]) + + def __init__(self: _t.Self, *args: _t.Any, **kwargs: _t.Any) -> None: + super().__init__(*args, **kwargs) + self.hist_a: list[int] = [] + self.hist_b: list[int] = [] + + async def step(self) -> None: + """A test step.""" + pass + + @EventTypeA.handler(populates_fields=["a"]) + async def event_A_handler(self, evt: EventTypeA) -> None: + """A test event handler.""" + self.a = evt.data.x + self.hist_a.append(self.a) + + @EventTypeB.handler(populates_fields=["b"]) + async def event_B_handler(self, evt: EventTypeB) -> None: + """A test event handler.""" + self.b = evt.data.y + self.hist_b.append(self.b) + + +async def test_component_event_handlers_populates_fields( + connector_builder: ConnectorBuilder, +) -> None: + """Test that event handlers can populate fields for components.""" + b = B(name="b") + + assert b.io.event_field_coverage == { + EventTypeA.safe_type(): ["a"], + EventTypeB.safe_type(): ["b"], + } + + assert b.io.dict() == { + "namespace": "b", + "inputs": ["a", "b"], + "outputs": [], + "input_events": [StopEvent.safe_type(), EventTypeA.safe_type(), EventTypeB.safe_type()], + "output_events": [StopEvent.safe_type()], + "event_field_coverage": { + EventTypeA.safe_type(): ["a"], + EventTypeB.safe_type(): ["b"], + }, + "initial_values": {}, + } + + connectors = connector_builder.build_event_connectors([b]) + event_connectors_map = {conn.spec.source.entity: conn for conn in connectors} + + await b.io.connect(connectors) + + assert b.hist_a == [] + assert b.hist_b == [] + assert getattr(b, "a", None) is None + assert getattr(b, "b", None) is None + + chan_A = await event_connectors_map[EventTypeA.safe_type()].connect_send() + chan_B = await event_connectors_map[EventTypeB.safe_type()].connect_send() + + evt_A = EventTypeA(data=EventTypeAData(x=2), source="test-driver") + await chan_A.send(evt_A) + await b.step() + + assert b.hist_a == [2] + assert b.hist_b == [] + assert getattr(b, "a", None) == 2 + assert getattr(b, "b", None) is None + + evt_B = EventTypeB(data=EventTypeBData(y=4), source="test-driver") + await chan_B.send(evt_B) + await b.step() + + assert b.hist_a == [2] + assert b.hist_b == [4] + assert getattr(b, "a", None) == 2 + assert getattr(b, "b", None) == 4 + + evt_A = EventTypeA(data=EventTypeAData(x=3), source="test-driver") + await chan_A.send(evt_A) + await b.step() + + assert b.hist_a == [2, 3] + assert b.hist_b == [4] + assert getattr(b, "a", None) == 3 + assert getattr(b, "b", None) == 4 + + evt_B = EventTypeB(data=EventTypeBData(y=5), source="test-driver") + await chan_B.send(evt_B) + await b.step() + + assert b.hist_a == [2, 3] + assert b.hist_b == [4, 5] + assert getattr(b, "a", None) == 3 + assert getattr(b, "b", None) == 5 + + await b.io.close() diff --git a/tests/integration/test_process_with_components_run.py b/tests/integration/test_process_with_components_run.py index fe047ae8..ca599d1e 100644 --- a/tests/integration/test_process_with_components_run.py +++ b/tests/integration/test_process_with_components_run.py @@ -23,6 +23,7 @@ ) from plugboard.events import Event from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError +from plugboard.library import FileWriter from plugboard.process import LocalProcess, Process, RayProcess from plugboard.schemas import ConnectorSpec, Status from tests.conftest import ComponentTestHelper, zmq_connector_cls @@ -459,6 +460,99 @@ async def test_event_driven_process_shutdown( await process.destroy() +class MessageEventData(BaseModel): + """Data for a message event.""" + + message: str + + +class MessageEvent(Event): + """Event carrying a file-writer message.""" + + type: _t.ClassVar[str] = "message_event" + data: MessageEventData + + +class MessageEventGenerator(ComponentTestHelper): + """Produces a fixed number of message events.""" + + io = IO(output_events=[MessageEvent]) + + def __init__( + self, + iters: int, + *args: _t.Any, + delay: float = 0.0, + start: int = 0, + stride: int = 1, + **kwargs: _t.Any, + ) -> None: + super().__init__(*args, **kwargs) + self._iters = iters + self._delay = delay + self._start = start + self._stride = stride + + async def init(self) -> None: + await super().init() + self._seq = iter(range(self._start, self._start + self._iters * self._stride, self._stride)) + + async def step(self) -> None: + # Optional delay to simulate staggered event arrival + if self._delay > 0.0: + await asyncio.sleep(self._delay) + try: + idx = next(self._seq) + except StopIteration: + await self.io.close() + else: + evt = MessageEvent( + source=self.name, + data=MessageEventData(message=f"Message {idx}"), + ) + self.io.queue_event(evt) + await super().step() + + +class EventReaderFileWriter(FileWriter): + """`FileWriter` variant that adds event handling instead of a connector for `message`.""" + + io = IO(input_events=[MessageEvent]) + + @MessageEvent.handler(populates_fields=["message"]) + async def handle_message(self, event: MessageEvent) -> None: + self.message = event.data.message + + +@pytest.mark.asyncio +async def test_event_driven_file_writer_reuse(tmp_path: Path) -> None: + """Test that field-input components can be reused in event-driven processes.""" + output_path = tmp_path / "output_messages.csv" + components = [ + MessageEventGenerator(iters=3, name="message_event_generator"), + EventReaderFileWriter( + path=output_path, + name="event_reader_file_writer", + field_names=["message"], + ), + ] + event_connectors = AsyncioConnector.builder().build_event_connectors(components) + process = LocalProcess(components=components, connectors=event_connectors) + + await process.init() + await process.run() + + assert process.status == Status.COMPLETED + assert output_path.read_text().splitlines() == [ + "message", + "Message 0", + "Message 1", + "Message 2", + ] + + await process.destroy() + + _SHORT_TIMEOUT = 0.1 @@ -536,3 +630,88 @@ async def test_constraint_error_stops_background_status_check() -> None: ) await process.destroy() + + +class StaggeredEventFileWriter(FileWriter): + """`FileWriter` variant that adds event handling instead of a connector for `message`.""" + + io = IO(input_events=[MessageEvent]) + + def __init__(self, *args: _t.Any, field_names: list[str], **kwargs: _t.Any) -> None: + super().__init__(*args, field_names=field_names, **kwargs) + self.step_count: int = 0 + self.step_for_message: dict[str, int] = {} + + @MessageEvent.handler(populates_fields=["mg1", "mg2", "mg3"]) + async def handle_message(self, event: MessageEvent) -> None: + msg = event.data.message + match event.source: + case "mg1": + self.mg1 = msg + case "mg2": + self.mg2 = msg + case "mg3": + self.mg3 = msg + case _: + raise ValueError(f"Unexpected event source: {event.source}") + self.step_for_message[msg] = self.step_count + self.step_count += 1 + + +@pytest.mark.asyncio +@pytest_cases.parametrize( + "process_cls, connector_cls", + [ + (LocalProcess, AsyncioConnector), + ], +) +async def test_data_writer_handles_staggered_input_events( + process_cls: type[Process], connector_cls: type[Connector], tmp_path: Path, ray_ctx: None +) -> None: + """Test that a FileWriter can handle input events arriving in different steps. + + Input messages with data for different fields may arrive in different steps. The FileWriter + should write out a new row only when all required fields have received data, and should not + overwrite field values if only a subset of fields receive new data in a step. + """ + output_path = tmp_path / "staggered_output_messages.csv" + + writer = StaggeredEventFileWriter( + path=output_path, field_names=["mg1", "mg2", "mg3"], name="writer" + ) + components = [ + # 3 inputs with different delays + MessageEventGenerator(iters=10, delay=0.005, start=0, stride=3, name="mg1"), + MessageEventGenerator(iters=10, delay=0.010, start=1, stride=3, name="mg2"), + MessageEventGenerator(iters=10, delay=0.020, start=2, stride=3, name="mg3"), + writer, + ] + + async with process_cls( + components=components, + connectors=AsyncioConnector.builder().build_event_connectors(components), + ) as process: + await process.run() + + with output_path.open() as f: + content = f.read().splitlines() + + assert len(content) == 11 # header + 10 rows of data + assert content[0] == "mg1,mg2,mg3" + assert content[1] == "Message 0,Message 1,Message 2" + assert content[2] == "Message 3,Message 4,Message 5" + assert content[3] == "Message 6,Message 7,Message 8" + assert content[4] == "Message 9,Message 10,Message 11" + assert content[5] == "Message 12,Message 13,Message 14" + assert content[6] == "Message 15,Message 16,Message 17" + assert content[7] == "Message 18,Message 19,Message 20" + assert content[8] == "Message 21,Message 22,Message 23" + assert content[9] == "Message 24,Message 25,Message 26" + assert content[10] == "Message 27,Message 28,Message 29" + + # Verify that messages from different generators were received in different steps + assert writer.step_count == 30 + assert len(writer.step_for_message) == 30 + assert len(set(writer.step_for_message.values())) == 30, ( + "Expected each message to be received in a different step" + ) diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py index 02e0a4d2..df20132e 100644 --- a/tests/unit/test_process_validation.py +++ b/tests/unit/test_process_validation.py @@ -95,6 +95,7 @@ def _make_component( outputs: list[str] | None = None, input_events: list[str] | None = None, output_events: list[str] | None = None, + event_field_coverage: dict[str, list[str]] | None = None, initial_values: dict[str, _t.Any] | None = None, ) -> dict[str, _t.Any]: """Build a component dict matching process.dict() format.""" @@ -108,6 +109,7 @@ def _make_component( "outputs": outputs or [], "input_events": input_events or [], "output_events": output_events or [], + "event_field_coverage": event_field_coverage or {}, "initial_values": initial_values or {}, }, } @@ -303,6 +305,22 @@ def test_no_inputs_no_errors(self) -> None: errors = validate_all_inputs_connected(pd) assert errors == [] + def test_event_covered_fields(self) -> None: + """Unconnected inputs are allowed when non-system input events can populate them.""" + pd = _make_process_dict( + components={ + "producer": _make_component("producer", output_events=["message_event"]), + "writer": _make_component( + "writer", + inputs=["message"], + input_events=["system_stop", "message_event"], + event_field_coverage={"message_event": ["message"]}, + ), + }, + ) + errors = validate_all_inputs_connected(pd) + assert errors == [] + # --------------------------------------------------------------------------- # Tests for validate_input_events