From c19abb2593edcc7113ae6588492124949253813b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 22 Mar 2026 17:31:16 -0700 Subject: [PATCH 1/2] feat: Implement direct device trait updates from data protocol messages using `dps` metadata and add corresponding update listeners. This uses the same dps converter patern used by q10, but does not share code explicitly. --- roborock/data/v1/v1_containers.py | 21 ++--- roborock/devices/device.py | 4 +- roborock/devices/device_manager.py | 1 + roborock/devices/rpc/v1_channel.py | 34 +++++-- roborock/devices/traits/v1/__init__.py | 28 +++++- roborock/devices/traits/v1/common.py | 76 +++++++++++++++- roborock/devices/traits/v1/consumeable.py | 18 +++- roborock/devices/traits/v1/status.py | 20 ++++- roborock/protocols/v1_protocol.py | 63 ++++++++++--- roborock/roborock_message.py | 2 +- tests/devices/rpc/test_v1_channel.py | 4 +- tests/devices/test_v1_device.py | 1 + tests/devices/traits/v1/fixtures.py | 3 +- tests/devices/traits/v1/test_status.py | 76 ++++++++++++++++ .../__snapshots__/test_device_manager.ambr | 8 ++ tests/protocols/test_v1_protocol.py | 88 ++++++++++++++++++- 16 files changed, 407 insertions(+), 40 deletions(-) diff --git a/roborock/data/v1/v1_containers.py b/roborock/data/v1/v1_containers.py index cf988590..5c227bcb 100644 --- a/roborock/data/v1/v1_containers.py +++ b/roborock/data/v1/v1_containers.py @@ -37,6 +37,7 @@ ROBOROCK_G20S_Ultra, ) from roborock.exceptions import RoborockException +from roborock.roborock_message import RoborockDataProtocol from ..containers import NamedRoomMapping, RoborockBase, RoborockBaseTimer, _attr_repr from .v1_clean_modes import WashTowelModes @@ -293,11 +294,11 @@ class StatusV2(RoborockBase): msg_ver: int | None = None msg_seq: int | None = None - state: RoborockStateCode | None = None - battery: int | None = None + state: RoborockStateCode | None = field(default=None, metadata={"dps": RoborockDataProtocol.STATE}) + battery: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.BATTERY}) clean_time: int | None = None clean_area: int | None = None - error_code: RoborockErrorCode | None = None + error_code: RoborockErrorCode | None = field(default=None, metadata={"dps": RoborockDataProtocol.ERROR_CODE}) map_present: int | None = None in_cleaning: RoborockInCleaning | None = None in_returning: int | None = None @@ -307,12 +308,12 @@ class StatusV2(RoborockBase): back_type: int | None = None wash_phase: int | None = None wash_ready: int | None = None - fan_power: int | None = None + fan_power: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.FAN_POWER}) dnd_enabled: int | None = None map_status: int | None = None is_locating: int | None = None lock_status: int | None = None - water_box_mode: int | None = None + water_box_mode: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.WATER_BOX_MODE}) water_box_carriage_status: int | None = None mop_forbidden_enable: int | None = None camera_status: int | None = None @@ -330,13 +331,13 @@ class StatusV2(RoborockBase): collision_avoid_status: int | None = None switch_map_mode: int | None = None dock_error_status: RoborockDockErrorCode | None = None - charge_status: int | None = None + charge_status: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.CHARGE_STATUS}) unsave_map_reason: int | None = None unsave_map_flag: int | None = None wash_status: int | None = None distance_off: int | None = None in_warmup: int | None = None - dry_status: int | None = None + dry_status: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.DRYING_STATUS}) rdt: int | None = None clean_percent: int | None = None rss: int | None = None @@ -626,9 +627,9 @@ class CleanSummaryWithDetail(CleanSummary): @dataclass class Consumable(RoborockBase): - main_brush_work_time: int | None = None - side_brush_work_time: int | None = None - filter_work_time: int | None = None + main_brush_work_time: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.MAIN_BRUSH_WORK_TIME}) + side_brush_work_time: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.SIDE_BRUSH_WORK_TIME}) + filter_work_time: int | None = field(default=None, metadata={"dps": RoborockDataProtocol.FILTER_WORK_TIME}) filter_element_work_time: int | None = None sensor_dirty_time: int | None = None strainer_work_times: int | None = None diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 29f1fd28..bf020814 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -199,7 +199,7 @@ async def connect(self) -> None: unsub = await self._channel.subscribe(self._on_message) try: if self.v1_properties is not None: - await self.v1_properties.discover_features() + await self.v1_properties.start() elif self.b01_q10_properties is not None: await self.b01_q10_properties.start() except RoborockException: @@ -216,6 +216,8 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self.v1_properties is not None: + self.v1_properties.close() if self.b01_q10_properties is not None: await self.b01_q10_properties.close() if self._unsub: diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index b1ef6626..d4d64fb1 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -236,6 +236,7 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat channel.rpc_channel, channel.mqtt_rpc_channel, channel.map_rpc_channel, + channel.add_dps_listener, web_api, device_cache=device_cache, map_parser_config=map_parser_config, diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index d1b4ee24..81c3466d 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from typing import Any, TypeVar +from roborock.callbacks import CallbackList from roborock.data import HomeDataDevice, NetworkInfo, RoborockBase, UserData from roborock.devices.cache import DeviceCache from roborock.devices.transport.channel import Channel @@ -30,9 +31,10 @@ V1RpcChannel, create_map_response_decoder, create_security_data, + decode_data_protocol_message, decode_rpc_response, ) -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from roborock.util import RoborockLoggerAdapter @@ -188,6 +190,7 @@ def __init__( self._device_cache = device_cache self._reconnect_task: asyncio.Task[None] | None = None self._last_network_info_refresh: datetime.datetime | None = None + self._dps_listeners = CallbackList[dict[RoborockDataProtocol, Any]]() @property def is_connected(self) -> bool: @@ -305,12 +308,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - if not self.is_local_connected: - # We were not able to connect locally, so fallback to MQTT and at least - # establish that connection explicitly. If this fails then raise an - # error and let the caller know we failed to subscribe. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - self._logger.debug("V1Channel connected to device via MQTT") + # Always subscribe to MQTT to receive protocol updates (data points) + # even if we have a local connection. Protocol updates only come via cloud/MQTT. + # Local connection is used for RPC commands, but push notifications come via MQTT. + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + if self.is_local_connected: + self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") + else: + self._logger.debug("V1Channel connected via MQTT only") def unsub() -> None: """Unsubscribe from all messages.""" @@ -328,6 +333,16 @@ def unsub() -> None: self._callback = callback return unsub + def add_dps_listener(self, listener: Callable[[dict[RoborockDataProtocol, Any]], None]) -> Callable[[], None]: + """Add a listener for DPS updates. + + This will attach a listener to the existing subscription, invoking + the listener whenever new DPS values arrive from the subscription. + This will only work if a subscription has already been setup, which is + handled by the device setup. + """ + return self._dps_listeners.add_callback(listener) + async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInfo: """Retrieve networking information for the device. @@ -428,6 +443,11 @@ def _on_mqtt_message(self, message: RoborockMessage) -> None: self._logger.debug("V1Channel received MQTT message: %s", message) if self._callback: self._callback(message) + try: + if datapoints := decode_data_protocol_message(message): + self._dps_listeners(datapoints) + except RoborockException as e: + self._logger.debug("Error decoding data protocol message: %s", e) def _on_local_message(self, message: RoborockMessage) -> None: """Handle incoming local messages.""" diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index ad4b062a..fa5890a0 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -53,6 +53,7 @@ """ import logging +from collections.abc import Callable from dataclasses import dataclass, field, fields from typing import Any, get_args @@ -60,8 +61,10 @@ from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode from roborock.devices.cache import DeviceCache from roborock.devices.traits import Trait +from roborock.exceptions import RoborockException from roborock.map.map_parser import MapParserConfig -from roborock.protocols.v1_protocol import V1RpcChannel +from roborock.protocols.v1_protocol import V1RpcChannel, decode_data_protocol_message +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage from roborock.web_api import UserWebApiClient from . import ( @@ -176,6 +179,7 @@ def __init__( rpc_channel: V1RpcChannel, mqtt_rpc_channel: V1RpcChannel, map_rpc_channel: V1RpcChannel, + add_dps_listener: Callable[[Callable[[dict[RoborockDataProtocol, Any]], None]], Callable[[], None]], web_api: UserWebApiClient, device_cache: DeviceCache, map_parser_config: MapParserConfig | None = None, @@ -189,6 +193,8 @@ def __init__( self._web_api = web_api self._device_cache = device_cache self._region = region + self._unsub: Callable[[], None] | None = None + self._add_dps_listener = add_dps_listener self.device_features = DeviceFeaturesTrait(product, self._device_cache) self.status = StatusTrait(self.device_features, region=self._region) @@ -227,6 +233,24 @@ def _get_rpc_channel(self, trait: V1TraitMixin) -> V1RpcChannel: else: return self._rpc_channel + async def start(self) -> None: + """Start the properties API and discover features.""" + await self.discover_features() + self._unsub = self._add_dps_listener(self._on_dps_update) + + def close(self) -> None: + if self._unsub: + self._unsub() + + def _on_dps_update(self, dps: dict[RoborockDataProtocol, Any]) -> None: + """Handle incoming messages from the device. + + This will notify all traits of the new values. + """ + _LOGGER.debug("Received message from device: %s", dps) + self.status.update_from_dps(dps) + self.consumables.update_from_dps(dps) + async def discover_features(self) -> None: """Populate any supported traits that were not initialized in __init__.""" _LOGGER.debug("Starting optional trait discovery") @@ -330,6 +354,7 @@ def create( rpc_channel: V1RpcChannel, mqtt_rpc_channel: V1RpcChannel, map_rpc_channel: V1RpcChannel, + add_dps_listener: Callable[[Callable[[dict[RoborockDataProtocol, Any]], None]], Callable[[], None]], web_api: UserWebApiClient, device_cache: DeviceCache, map_parser_config: MapParserConfig | None = None, @@ -343,6 +368,7 @@ def create( rpc_channel, mqtt_rpc_channel, map_rpc_channel, + add_dps_listener, web_api, device_cache, map_parser_config, diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index ce2c899e..719a2568 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -5,11 +5,14 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import fields -from typing import ClassVar +from typing import Any, ClassVar +from roborock.callbacks import CallbackList from roborock.data import RoborockBase from roborock.protocols.v1_protocol import V1RpcChannel +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand _LOGGER = logging.getLogger(__name__) @@ -173,3 +176,74 @@ def wrapper(*args, **kwargs): cls.map_rpc_channel = True # type: ignore[attr-defined] return wrapper + + +# TODO(allenporter): Merge with roborock.devices.traits.b01.q10.common.TraitUpdateListener +class TraitUpdateListener(ABC): + """Trait update listener. + + This is a base class for traits to support notifying listeners when they + have been updated. Clients may register callbacks to be notified when the + trait has been updated. When the listener callback is invoked, the client + should read the trait's properties to get the updated values. + """ + + def __init__(self, logger: logging.Logger) -> None: + """Initialize the trait update listener.""" + self._update_callbacks: CallbackList[None] = CallbackList(logger=logger) + + def add_update_listener(self, callback: Callable[[], None]) -> Callable[[], None]: + """Register a callback when the trait has been updated. + + Returns a callable to remove the listener. + """ + # We wrap the callback to ignore the value passed to it. + return self._update_callbacks.add_callback(lambda _: callback()) + + def _notify_update(self) -> None: + """Notify all update listeners.""" + self._update_callbacks(None) + + +class DpsDataConverter: + """Utility to handle the transformation and merging of DPS data into models. + + This class pre-calculates the mapping between Data Point IDs and dataclass fields + to optimize repeated updates from device streams. + """ + + def __init__(self, dps_type_map: dict[RoborockDataProtocol, type], dps_field_map: dict[RoborockDataProtocol, str]): + """Initialize the converter for a specific RoborockBase-derived class.""" + self._dps_type_map = dps_type_map + self._dps_field_map = dps_field_map + + @classmethod + def from_dataclass(cls, dataclass_type: type[RoborockBase]): + """Initialize the converter for a specific RoborockBase-derived class.""" + dps_type_map: dict[RoborockDataProtocol, type] = {} + dps_field_map: dict[RoborockDataProtocol, str] = {} + for field_obj in fields(dataclass_type): + if field_obj.metadata and "dps" in field_obj.metadata: + dps_id = field_obj.metadata["dps"] + dps_type_map[dps_id] = field_obj.type + dps_field_map[dps_id] = field_obj.name + return cls(dps_type_map, dps_field_map) + + def update_from_dps(self, target: RoborockBase, decoded_dps: dict[RoborockDataProtocol, Any]) -> bool: + """Convert and merge raw DPS data into the target object. + + Uses the pre-calculated type mapping to ensure values are converted to the + correct Python types before being updated on the target. + + Args: + target: The target object to update. + decoded_dps: The decoded DPS data to convert. + + Returns: + True if any values were updated, False otherwise. + """ + conversions = RoborockBase.convert_dict(self._dps_type_map, decoded_dps) + for dps_id, value in conversions.items(): + field_name = self._dps_field_map[dps_id] + setattr(target, field_name, value) + return bool(conversions) diff --git a/roborock/devices/traits/v1/consumeable.py b/roborock/devices/traits/v1/consumeable.py index 0e59432f..897af012 100644 --- a/roborock/devices/traits/v1/consumeable.py +++ b/roborock/devices/traits/v1/consumeable.py @@ -5,16 +5,21 @@ """ from enum import StrEnum -from typing import Self +from typing import Any, Self from roborock.data import Consumable from roborock.devices.traits.v1 import common +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand +from .common import TraitUpdateListener + __all__ = [ "ConsumableTrait", ] +_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(Consumable) + class ConsumableAttribute(StrEnum): """Enum for consumable attributes.""" @@ -33,7 +38,7 @@ def from_str(cls, value: str) -> Self: raise ValueError(f"Unknown ConsumableAttribute: {value}") -class ConsumableTrait(Consumable, common.V1TraitMixin): +class ConsumableTrait(Consumable, common.V1TraitMixin, TraitUpdateListener): """Trait for managing consumable attributes on Roborock devices. After the first refresh, you can tell what consumables are supported by @@ -47,3 +52,12 @@ async def reset_consumable(self, consumable: ConsumableAttribute) -> None: """Reset a specific consumable attribute on the device.""" await self.rpc_channel.send_command(RoborockCommand.RESET_CONSUMABLE, params=[consumable.value]) await self.refresh() + + def update_from_dps(self, decoded_dps: dict[RoborockDataProtocol, Any]) -> None: + """Update the trait from data protocol push message data. + + This handles unsolicited status updates pushed by the device + via RoborockDataProtocol codes (e.g. STATE=121, BATTERY=122). + """ + if _DPS_CONVERTER.update_from_dps(self, decoded_dps): + self._notify_update() diff --git a/roborock/devices/traits/v1/status.py b/roborock/devices/traits/v1/status.py index 82371c15..84cdcb9b 100644 --- a/roborock/devices/traits/v1/status.py +++ b/roborock/devices/traits/v1/status.py @@ -1,4 +1,6 @@ +import logging from functools import cached_property +from typing import Any from roborock import ( CleanRoutes, @@ -10,13 +12,19 @@ get_water_mode_mapping, get_water_modes, ) +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand from . import common +from .common import TraitUpdateListener from .device_features import DeviceFeaturesTrait +_LOGGER = logging.getLogger(__name__) -class StatusTrait(StatusV2, common.V1TraitMixin): +_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(StatusV2) + + +class StatusTrait(StatusV2, common.V1TraitMixin, TraitUpdateListener): """Trait for managing the status of Roborock devices. The StatusTrait gives you the access to the state of a Roborock vacuum. @@ -47,6 +55,7 @@ class StatusTrait(StatusV2, common.V1TraitMixin): def __init__(self, device_feature_trait: DeviceFeaturesTrait, region: str | None = None) -> None: """Initialize the StatusTrait.""" super().__init__() + TraitUpdateListener.__init__(self, logger=_LOGGER) self._device_features_trait = device_feature_trait self._region = region @@ -91,3 +100,12 @@ def mop_route_name(self) -> str | None: if self.mop_mode is None: return None return self.mop_route_mapping.get(self.mop_mode) + + def update_from_dps(self, decoded_dps: dict[RoborockDataProtocol, Any]) -> None: + """Update the trait from data protocol push message data. + + This handles unsolicited status updates pushed by the device + via RoborockDataProtocol codes (e.g. STATE=121, BATTERY=122). + """ + if _DPS_CONVERTER.update_from_dps(self, decoded_dps): + self._notify_update() diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 355043c5..14144793 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -15,7 +15,7 @@ from roborock.data import RoborockBase, RRiot from roborock.exceptions import RoborockException, RoborockInvalidStatus, RoborockUnsupportedFeature from roborock.protocol import Utils -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from roborock.util import get_next_int, get_timestamp @@ -24,6 +24,7 @@ __all__ = [ "SecurityData", "create_security_data", + "decode_data_protocol_message", "decode_rpc_response", "V1RpcChannel", ] @@ -139,6 +140,28 @@ class ResponseMessage: """The API error message of the response if any.""" +def _decode_dps_message(message: RoborockMessage) -> dict[int, Any] | None: + """Decode a V1 push message containing data protocol updates.""" + if not message.payload: + return None + try: + payload = json.loads(message.payload.decode()) + except (json.JSONDecodeError, TypeError, UnicodeDecodeError) as e: + raise RoborockException(f"Invalid V1 message payload: {e} for {message.payload!r}") from e + + datapoints = payload.get("dps") + if not isinstance(datapoints, dict): + return None + result: dict[int, Any] = {} + for key, value in datapoints.items(): + try: + code = int(key) + except (ValueError, TypeError): + continue + result[code] = value + return result if result else None + + def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: """Decode a V1 RPC_RESPONSE message. @@ -147,19 +170,10 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: response, as long as we can extract the request ID. This is so we can associate an API response with a request even if there was an error. """ - if not message.payload: + if not (datapoints := _decode_dps_message(message)): return ResponseMessage(request_id=message.seq, data={}) - try: - payload = json.loads(message.payload.decode()) - except (json.JSONDecodeError, TypeError, UnicodeDecodeError) as e: - raise RoborockException(f"Invalid V1 message payload: {e} for {message.payload!r}") from e - - _LOGGER.debug("Decoded V1 message payload: %s", payload) - datapoints = payload.get("dps", {}) - if not isinstance(datapoints, dict): - raise RoborockException(f"Invalid V1 message format: 'dps' should be a dictionary for {message.payload!r}") - if not (data_point := datapoints.get(str(RoborockMessageProtocol.RPC_RESPONSE))): + if not (data_point := datapoints.get(RoborockMessageProtocol.RPC_RESPONSE)): raise RoborockException( f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point" ) @@ -206,6 +220,31 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: return ResponseMessage(request_id=request_id, data=result, api_error=api_error) +def decode_data_protocol_message(message: RoborockMessage) -> dict[RoborockDataProtocol, Any] | None: + """Decode a V1 push message containing data protocol updates. + + V1 devices push unsolicited status updates containing data points keyed + by RoborockDataProtocol codes (e.g., 121=STATE, 122=BATTERY). This function + extracts those data points from the message payload. + + Returns a dict mapping RoborockDataProtocol to values, or None if the + message does not contain any recognized data protocol updates. + """ + if not (datapoints := _decode_dps_message(message)): + return None + + result: dict[RoborockDataProtocol, Any] = {} + for code, value in datapoints.items(): + try: + protocol = RoborockDataProtocol(code) + except ValueError: + _LOGGER.debug("Ignoring unknown V1 data protocol code: %s", code) + continue + result[protocol] = value + + return result if result else None + + @dataclass class MapResponse: """Data structure for the V1 Map response.""" diff --git a/roborock/roborock_message.py b/roborock/roborock_message.py index fbe9becf..36d94640 100644 --- a/roborock/roborock_message.py +++ b/roborock/roborock_message.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from enum import StrEnum -from roborock import RoborockEnum +from roborock.data.code_mappings import RoborockEnum from roborock.util import get_next_int, get_timestamp diff --git a/tests/devices/rpc/test_v1_channel.py b/tests/devices/rpc/test_v1_channel.py index 293eb260..fac52666 100644 --- a/tests/devices/rpc/test_v1_channel.py +++ b/tests/devices/rpc/test_v1_channel.py @@ -250,8 +250,8 @@ async def test_v1_channel_subscribe_local_success( mock_local_session.assert_called_once_with(TEST_HOST) mock_local_channel.connect.assert_called_once() - # Verify local connection established and not mqtt - assert not mock_mqtt_channel.subscribers + # Verify mqtt is also established + assert mock_mqtt_channel.subscribers assert mock_local_channel.subscribers # Verify properties diff --git a/tests/devices/test_v1_device.py b/tests/devices/test_v1_device.py index 558d838c..8afc62cd 100644 --- a/tests/devices/test_v1_device.py +++ b/tests/devices/test_v1_device.py @@ -62,6 +62,7 @@ def device_fixture(channel: AsyncMock, rpc_channel: AsyncMock, mqtt_rpc_channel: rpc_channel, mqtt_rpc_channel, AsyncMock(), + Mock(), AsyncMock(), device_cache=DeviceCache(HOME_DATA.devices[0].duid, NoCache()), region=USER_DATA.region, diff --git a/tests/devices/traits/v1/fixtures.py b/tests/devices/traits/v1/fixtures.py index 08397493..bf42d151 100644 --- a/tests/devices/traits/v1/fixtures.py +++ b/tests/devices/traits/v1/fixtures.py @@ -1,7 +1,7 @@ """Fixtures for V1 trait tests.""" from copy import deepcopy -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pytest @@ -94,6 +94,7 @@ def device_fixture( mock_rpc_channel, mock_mqtt_rpc_channel, mock_map_rpc_channel, + Mock(), web_api_client, device_cache=device_cache, region=USER_DATA.region, diff --git a/tests/devices/traits/v1/test_status.py b/tests/devices/traits/v1/test_status.py index a308dbca..f6fe53e5 100644 --- a/tests/devices/traits/v1/test_status.py +++ b/tests/devices/traits/v1/test_status.py @@ -1,5 +1,6 @@ """Tests for the StatusTrait class.""" +import asyncio from typing import cast from unittest.mock import AsyncMock @@ -14,6 +15,7 @@ from roborock.devices.traits.v1.device_features import DeviceFeaturesTrait from roborock.devices.traits.v1.status import StatusTrait from roborock.exceptions import RoborockException +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand from tests import mock_data from tests.mock_data import STATUS @@ -122,3 +124,77 @@ def test_water_slide_mode_mapping() -> None: assert status_trait.water_mode_name == "low" status_trait.water_box_mode = 200 assert status_trait.water_mode_name == "off" + + +def test_update_from_dps(status_trait: StatusTrait) -> None: + """Test updating status from data protocol push message.""" + assert status_trait.battery is None + assert status_trait.state is None + + status_trait.update_from_dps( + { + RoborockDataProtocol.STATE: 5, + RoborockDataProtocol.BATTERY: 85, + RoborockDataProtocol.FAN_POWER: 102, + } + ) + + assert status_trait.state == 5 + assert status_trait.battery == 85 + assert status_trait.fan_power == 102 + + +def test_update_from_dps_partial(status_trait: StatusTrait) -> None: + """Test that partial updates only modify the specified fields.""" + status_trait.battery = 100 + status_trait.state = RoborockStateCode.charging + + status_trait.update_from_dps( + { + RoborockDataProtocol.BATTERY: 90, + } + ) + + assert status_trait.battery == 90 + assert status_trait.state == RoborockStateCode.charging # Unchanged + + +def test_update_listener(status_trait: StatusTrait) -> None: + """Test that update listeners receive notifications.""" + event = asyncio.Event() + unsubscribe = status_trait.add_update_listener(event.set) + + status_trait.update_from_dps( + { + RoborockDataProtocol.BATTERY: 88, + } + ) + + assert event.is_set() + event.clear() + + unsubscribe() + + status_trait.update_from_dps( + { + RoborockDataProtocol.BATTERY: 87, + } + ) + + assert not event.is_set() + + +def test_update_listener_ignores_unrelated(status_trait: StatusTrait) -> None: + """Test that update listeners are not notified for unrecognized data points.""" + event = asyncio.Event() + unsubscribe = status_trait.add_update_listener(event.set) + + # TASK_COMPLETE is not annotated with dps metadata on StatusV2 + status_trait.update_from_dps( + { + RoborockDataProtocol.TASK_COMPLETE: 1, + } + ) + + assert not event.is_set() + unsubscribe() diff --git a/tests/e2e/__snapshots__/test_device_manager.ambr b/tests/e2e/__snapshots__/test_device_manager.ambr index 17a045d4..0e8290ee 100644 --- a/tests/e2e/__snapshots__/test_device_manager.ambr +++ b/tests/e2e/__snapshots__/test_device_manager.ambr @@ -539,6 +539,14 @@ 00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h| 00000010 a6 a2 2b 00 01 00 10 6d b9 48 37 ed 43 59 7a 90 |..+....m.H7.CYz.| 00000020 ff 43 2f 0a 8f 81 44 e7 b6 b3 85 |.C/...D....| + [mqtt >] + 00000000 10 29 00 04 4d 51 54 54 05 c2 00 2d 00 00 00 00 |.)..MQTT...-....| + 00000010 08 31 39 36 34 38 66 39 34 00 10 32 33 34 36 37 |.19648f94..23467| + 00000020 38 65 61 38 35 34 66 31 39 39 65 |8ea854f199e| + [mqtt >] + 00000000 82 24 00 01 00 00 1e 72 72 2f 6d 2f 6f 2f 75 73 |.$.....rr/m/o/us| + 00000010 65 72 31 32 33 2f 31 39 36 34 38 66 39 34 2f 61 |er123/19648f94/a| + 00000020 62 63 31 32 33 00 |bc123.| [local >] 00000000 00 00 00 77 31 2e 30 00 00 23 8e 00 00 23 8f 68 |...w1.0..#...#.h| 00000010 a6 a2 2e 00 04 00 60 a9 a0 ac af 22 80 bb 11 b7 |......`...."....| diff --git a/tests/protocols/test_v1_protocol.py b/tests/protocols/test_v1_protocol.py index 1ec5026e..b5454057 100644 --- a/tests/protocols/test_v1_protocol.py +++ b/tests/protocols/test_v1_protocol.py @@ -17,9 +17,10 @@ RequestMessage, SecurityData, create_map_response_decoder, + decode_data_protocol_message, decode_rpc_response, ) -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from tests import mock_data @@ -309,3 +310,88 @@ def test_invalid_unicode() -> None: ) with pytest.raises(RoborockException, match="Invalid V1 message payload"): decode_rpc_response(message) + + +def test_decode_data_protocol_message() -> None: + """Test decoding a V1 push message with data protocol updates.""" + payload = json.dumps({"t": 1652547161, "dps": {"121": 8, "122": 95}}).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + result = decode_data_protocol_message(message) + assert result is not None + assert result[RoborockDataProtocol.STATE] == 8 + assert result[RoborockDataProtocol.BATTERY] == 95 + + +def test_decode_data_protocol_message_all_status_fields() -> None: + """Test decoding a push message with all known status data protocol fields.""" + payload = json.dumps( + { + "t": 1652547161, + "dps": {"120": 0, "121": 5, "122": 100, "123": 102, "124": 204, "133": 1}, + } + ).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + result = decode_data_protocol_message(message) + assert result is not None + assert result[RoborockDataProtocol.ERROR_CODE] == 0 + assert result[RoborockDataProtocol.STATE] == 5 + assert result[RoborockDataProtocol.BATTERY] == 100 + assert result[RoborockDataProtocol.FAN_POWER] == 102 + assert result[RoborockDataProtocol.WATER_BOX_MODE] == 204 + assert result[RoborockDataProtocol.CHARGE_STATUS] == 1 + + +def test_decode_data_protocol_message_unknown_codes() -> None: + """Test that unknown data protocol codes are ignored.""" + payload = json.dumps({"t": 1652547161, "dps": {"121": 8, "999": 42}}).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + result = decode_data_protocol_message(message) + assert result is not None + assert len(result) == 1 + assert result[RoborockDataProtocol.STATE] == 8 + + +def test_decode_data_protocol_message_empty_payload() -> None: + """Test decoding with empty payload returns None.""" + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=None, + ) + assert decode_data_protocol_message(message) is None + + +def test_decode_data_protocol_message_rpc_response() -> None: + """Test that an RPC response (code 102) produces None since the value is not a data protocol.""" + # This contains an RPC response (102) which has a JSON string value, not a data protocol code + payload = json.dumps( + { + "t": 1652547161, + "dps": {"102": '{"id":20001,"result":[{"state":8}]}'}, + } + ).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + # Code 102 is not in RoborockDataProtocol enum, so it should be ignored. + # The result should be None (no recognized data protocol codes). + assert decode_data_protocol_message(message) is None + + +def test_decode_data_protocol_message_no_dps() -> None: + """Test decoding message without dps returns None.""" + payload = json.dumps({"t": 1652547161}).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + assert decode_data_protocol_message(message) is None From 9ea6b03b5c3b959d273c640137ffa5d2ef37d482 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 22 Mar 2026 18:08:15 -0700 Subject: [PATCH 2/2] chore: Apply co-pilot feedback --- roborock/devices/rpc/v1_channel.py | 15 ++++++---- roborock/devices/traits/v1/__init__.py | 3 ++ roborock/devices/traits/v1/consumeable.py | 8 ++++++ roborock/protocols/v1_protocol.py | 7 ++++- tests/devices/rpc/test_v1_channel.py | 34 ++++++++++++++++++++++- tests/fixtures/channel_fixtures.py | 8 ++++++ 6 files changed, 67 insertions(+), 8 deletions(-) diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index 81c3466d..0e99e2fe 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -308,14 +308,17 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - # Always subscribe to MQTT to receive protocol updates (data points) + # Always attempt to subscribe to MQTT to receive protocol updates (data points) # even if we have a local connection. Protocol updates only come via cloud/MQTT. # Local connection is used for RPC commands, but push notifications come via MQTT. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - if self.is_local_connected: - self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") - else: - self._logger.debug("V1Channel connected via MQTT only") + try: + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + except RoborockException as err: + if not self.is_local_connected: + # Propagate error if both local and MQTT failed + self._logger.debug("MQTT connection also failed: %s", err) + raise + self._logger.debug("MQTT subscription failed, continuing with local-only connection: %s", err) def unsub() -> None: """Unsubscribe from all messages.""" diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index fa5890a0..a4b6c7d5 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -235,12 +235,15 @@ def _get_rpc_channel(self, trait: V1TraitMixin) -> V1RpcChannel: async def start(self) -> None: """Start the properties API and discover features.""" + if self._unsub: + return await self.discover_features() self._unsub = self._add_dps_listener(self._on_dps_update) def close(self) -> None: if self._unsub: self._unsub() + self._unsub = None def _on_dps_update(self, dps: dict[RoborockDataProtocol, Any]) -> None: """Handle incoming messages from the device. diff --git a/roborock/devices/traits/v1/consumeable.py b/roborock/devices/traits/v1/consumeable.py index 897af012..5d4b2cf0 100644 --- a/roborock/devices/traits/v1/consumeable.py +++ b/roborock/devices/traits/v1/consumeable.py @@ -4,6 +4,7 @@ periodically, such as filters, brushes, etc. """ +import logging from enum import StrEnum from typing import Any, Self @@ -18,6 +19,8 @@ "ConsumableTrait", ] +_LOGGER = logging.getLogger(__name__) + _DPS_CONVERTER = common.DpsDataConverter.from_dataclass(Consumable) @@ -48,6 +51,11 @@ class ConsumableTrait(Consumable, common.V1TraitMixin, TraitUpdateListener): command = RoborockCommand.GET_CONSUMABLE converter = common.DefaultConverter(Consumable) + def __init__(self) -> None: + """Initialize the consumable trait.""" + super().__init__() + TraitUpdateListener.__init__(self, logger=_LOGGER) + async def reset_consumable(self, consumable: ConsumableAttribute) -> None: """Reset a specific consumable attribute on the device.""" await self.rpc_channel.send_command(RoborockCommand.RESET_CONSUMABLE, params=[consumable.value]) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 14144793..8f039bb0 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -170,9 +170,14 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: response, as long as we can extract the request ID. This is so we can associate an API response with a request even if there was an error. """ - if not (datapoints := _decode_dps_message(message)): + if not message.payload: return ResponseMessage(request_id=message.seq, data={}) + if (datapoints := _decode_dps_message(message)) is None: + raise RoborockException( + f"Invalid V1 message format: missing or invalid 'dps' in payload for {message.payload!r}" + ) + if not (data_point := datapoints.get(RoborockMessageProtocol.RPC_RESPONSE)): raise RoborockException( f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point" diff --git a/tests/devices/rpc/test_v1_channel.py b/tests/devices/rpc/test_v1_channel.py index fac52666..015ce37e 100644 --- a/tests/devices/rpc/test_v1_channel.py +++ b/tests/devices/rpc/test_v1_channel.py @@ -23,7 +23,7 @@ create_mqtt_encoder, ) from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from tests import mock_data from tests.fixtures.channel_fixtures import FakeChannel @@ -580,3 +580,35 @@ async def test_v1_channel_send_map_command( # Verify the result is the data from our mocked decoder assert result == decompressed_map_data + + +async def test_v1_channel_add_dps_listener( + v1_channel: V1Channel, + mock_mqtt_channel: FakeChannel, +) -> None: + """Test that DPS listeners receive decoded protocol updates from MQTT.""" + mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE) + await v1_channel.subscribe(Mock()) + + # Create a mock listener for DPS updates + dps_listener = Mock() + unsub_dps = v1_channel.add_dps_listener(dps_listener) + + # Simulate an incoming MQTT message with data protocol payload. + dps_payload = json.dumps({"dps": {"121": 5}}).encode() + push_message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_REQUEST, + payload=dps_payload, + ) + mock_mqtt_channel.notify_subscribers(push_message) + + dps_listener.assert_called_once() + called_args = dps_listener.call_args[0][0] + assert called_args[RoborockDataProtocol.STATE] == 5 + + unsub_dps() + + # Verify unsubscribe works + dps_listener.reset_mock() + v1_channel._on_mqtt_message(push_message) + dps_listener.assert_not_called() diff --git a/tests/fixtures/channel_fixtures.py b/tests/fixtures/channel_fixtures.py index 1faae11c..90ace9fa 100644 --- a/tests/fixtures/channel_fixtures.py +++ b/tests/fixtures/channel_fixtures.py @@ -51,3 +51,11 @@ async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Calla """Simulate subscribing to messages.""" self.subscribers.append(callback) return lambda: self.subscribers.remove(callback) + + def notify_subscribers(self, message: RoborockMessage) -> None: + """Notify subscribers of a message. + + This can be used by tests to simulate the channel receiving a message. + """ + for subscriber in list(self.subscribers): + subscriber(message)