Skip to content

Commit 8786079

Browse files
committed
fix: add a decorator to mark traits as mqtt only
1 parent 8f25a60 commit 8786079

File tree

7 files changed

+63
-15
lines changed

7 files changed

+63
-15
lines changed

roborock/devices/device_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
149149
match device.pv:
150150
case DeviceVersion.V1:
151151
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
152-
trait = v1.create(product, channel.rpc_channel)
152+
trait = v1.create(product, channel.rpc_channel, channel.mqtt_rpc_channel)
153153
case DeviceVersion.A01:
154154
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
155155
trait = a01.create(product, channel)

roborock/devices/traits/v1/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Create traits for V1 devices."""
22

33
from dataclasses import dataclass, field, fields
4+
import logging
45

56
from roborock.containers import HomeData, HomeDataProduct
67
from roborock.devices.traits import Trait
@@ -13,6 +14,8 @@
1314
from .status import StatusTrait
1415
from .volume import SoundVolumeTrait
1516

17+
_LOGGER = logging.getLogger(__name__)
18+
1619
__all__ = [
1720
"create",
1821
"PropertiesApi",
@@ -40,10 +43,12 @@ class PropertiesApi(Trait):
4043

4144
# In the future optional fields can be added below based on supported features
4245

43-
def __init__(self, product: HomeDataProduct, rpc_channel: V1RpcChannel) -> None:
46+
def __init__(self, product: HomeDataProduct, rpc_channel: V1RpcChannel, mqtt_rpc_channel: V1RpcChannel) -> None:
4447
"""Initialize the V1TraitProps with None values."""
4548
self.status = StatusTrait(product)
49+
#self.status._rpc_channel = rpc_channel
4650
self.maps = MapsTrait(self.status)
51+
#self.maps._rpc_channel = mqtt_rpc_channel
4752

4853
# This is a hack to allow setting the rpc_channel on all traits. This is
4954
# used so we can preserve the dataclass behavior when the values in the
@@ -53,9 +58,12 @@ def __init__(self, product: HomeDataProduct, rpc_channel: V1RpcChannel) -> None:
5358
if (trait := getattr(self, item.name, None)) is None:
5459
trait = item.type()
5560
setattr(self, item.name, trait)
56-
trait._rpc_channel = rpc_channel
61+
if hasattr(trait, "mqtt_rpc_channel"): # @common.mqtt_rpc_channel
62+
trait._rpc_channel = mqtt_rpc_channel
63+
else:
64+
trait._rpc_channel = rpc_channel
5765

5866

59-
def create(product: HomeDataProduct, rpc_channel: V1RpcChannel) -> PropertiesApi:
67+
def create(product: HomeDataProduct, rpc_channel: V1RpcChannel, mqtt_rpc_channel: V1RpcChannel) -> PropertiesApi:
6068
"""Create traits for V1 devices."""
61-
return PropertiesApi(product, rpc_channel)
69+
return PropertiesApi(product, rpc_channel, mqtt_rpc_channel)

roborock/devices/traits/v1/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55

66
from abc import ABC
77
from dataclasses import dataclass, fields
8+
from functools import wraps
9+
import logging
810
from typing import ClassVar, Self
911

1012
from roborock.containers import RoborockBase
1113
from roborock.devices.v1_rpc_channel import V1RpcChannel
1214
from roborock.roborock_typing import RoborockCommand
1315

16+
_LOGGER = logging.getLogger(__name__)
17+
1418
V1ResponseData = dict | list | int | str
1519

1620

@@ -115,3 +119,16 @@ def _parse_response(cls, response: V1ResponseData) -> Self:
115119
raise ValueError(f"Unexpected response format: {response!r}")
116120
value_field = _get_value_field(cls)
117121
return cls(**{value_field: response})
122+
123+
124+
def mqtt_rpc_channel(cls):
125+
"""Decorator to mark a function as cloud only.
126+
127+
Normally a trait uses an adaptive rpc channel that can use either local
128+
or cloud communication depending on what is available. This will force
129+
the trait to always use the cloud rpc channel.
130+
"""
131+
def wrapper(*args, **kwargs):
132+
return cls(*args, **kwargs)
133+
cls.mqtt_rpc_channel = True # type: ignore[attr-defined]
134+
return wrapper

roborock/devices/traits/v1/maps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
_LOGGER = logging.getLogger(__name__)
1717

1818

19+
@common.mqtt_rpc_channel
1920
class MapsTrait(MultiMapsList, common.V1TraitMixin):
2021
"""Trait for managing the maps of Roborock devices.
2122

tests/devices/test_v1_device.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,18 @@ def rpc_channel_fixture() -> AsyncMock:
3535
return AsyncMock()
3636

3737

38+
@pytest.fixture(autouse=True, name="mqtt_rpc_channel")
39+
def mqtt_rpc_channel_fixture() -> AsyncMock:
40+
"""Fixture to set up the channel for tests."""
41+
return AsyncMock()
42+
3843
@pytest.fixture(autouse=True, name="device")
39-
def device_fixture(channel: AsyncMock, rpc_channel: AsyncMock) -> RoborockDevice:
44+
def device_fixture(channel: AsyncMock, rpc_channel: AsyncMock, mqtt_rpc_channel: AsyncMock) -> RoborockDevice:
4045
"""Fixture to set up the device for tests."""
4146
return RoborockDevice(
4247
device_info=HOME_DATA.devices[0],
4348
channel=channel,
44-
trait=v1.create(HOME_DATA.products[0], rpc_channel),
49+
trait=v1.create(HOME_DATA.products[0], rpc_channel, mqtt_rpc_channel),
4550
)
4651

4752

tests/devices/traits/v1/fixtures.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,17 @@ def rpc_channel_fixture() -> AsyncMock:
2727
return AsyncMock()
2828

2929

30+
31+
@pytest.fixture(autouse=True, name="mock_mqtt_rpc_channel")
32+
def mqtt_rpc_channel_fixture() -> AsyncMock:
33+
"""Fixture to set up the channel for tests."""
34+
return AsyncMock()
35+
3036
@pytest.fixture(autouse=True, name="device")
31-
def device_fixture(channel: AsyncMock, mock_rpc_channel: AsyncMock) -> RoborockDevice:
37+
def device_fixture(channel: AsyncMock, mock_rpc_channel: AsyncMock, mock_mqtt_rpc_channel: AsyncMock) -> RoborockDevice:
3238
"""Fixture to set up the device for tests."""
3339
return RoborockDevice(
3440
device_info=HOME_DATA.devices[0],
3541
channel=channel,
36-
trait=v1.create(HOME_DATA.products[0], mock_rpc_channel),
42+
trait=v1.create(HOME_DATA.products[0], mock_rpc_channel, mock_mqtt_rpc_channel),
3743
)

tests/devices/traits/v1/test_maps.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,20 @@ def maps_trait(device: RoborockDevice) -> MapsTrait:
5858
async def test_refresh_maps_trait(
5959
maps_trait: MapsTrait,
6060
mock_rpc_channel: AsyncMock,
61+
mock_mqtt_rpc_channel: AsyncMock,
6162
status_trait: StatusTrait,
6263
) -> None:
6364
"""Test successfully getting multi maps list."""
6465
# Setup mock to return the sample multi maps list
6566
mock_rpc_channel.send_command.side_effect = [
6667
mock_data.STATUS, # Initial status fetch
68+
]
69+
mock_mqtt_rpc_channel.send_command.side_effect = [
6770
MULTI_MAP_LIST_DATA,
6871
]
6972
await status_trait.refresh()
73+
assert status_trait.current_map == 0
74+
7075
# Populating the status information gives us the current map
7176
# flag, but we have not loaded the rest of the information.
7277
assert maps_trait.current_map == 0
@@ -96,22 +101,27 @@ async def test_refresh_maps_trait(
96101
assert maps_trait.current_map_info.name == "Map 1"
97102

98103
# Verify the RPC call was made correctly
99-
assert mock_rpc_channel.send_command.call_count == 2
104+
assert mock_rpc_channel.send_command.call_count == 1
100105
mock_rpc_channel.send_command.assert_any_call(RoborockCommand.GET_STATUS)
101-
mock_rpc_channel.send_command.assert_any_call(RoborockCommand.GET_MULTI_MAPS_LIST)
106+
assert mock_mqtt_rpc_channel.send_command.call_count == 1
107+
mock_mqtt_rpc_channel.send_command.assert_any_call(RoborockCommand.GET_MULTI_MAPS_LIST)
102108

103109

104110
async def test_set_current_map(
105111
status_trait: StatusTrait,
106112
maps_trait: MapsTrait,
107113
mock_rpc_channel: AsyncMock,
114+
mock_mqtt_rpc_channel: AsyncMock,
108115
) -> None:
109116
"""Test successfully setting the current map."""
117+
assert hasattr(maps_trait, "mqtt_rpc_channel")
110118
mock_rpc_channel.send_command.side_effect = [
111119
mock_data.STATUS, # Initial status fetch
120+
UPDATED_STATUS, # Response for refreshing status
121+
]
122+
mock_mqtt_rpc_channel.send_command.side_effect = [
112123
MULTI_MAP_LIST_DATA, # Response for LOAD_MULTI_MAP
113124
{}, # Response for setting the current map
114-
UPDATED_STATUS, # Response for refreshing status
115125
]
116126
await status_trait.refresh()
117127

@@ -139,7 +149,8 @@ async def test_set_current_map(
139149
# 2. GET_MULTI_MAPS_LIST to get the map list
140150
# 3. LOAD_MULTI_MAP to set the map
141151
# 4. GET_STATUS to refresh the current map in status
142-
assert mock_rpc_channel.send_command.call_count == 4
152+
assert mock_rpc_channel.send_command.call_count == 2
143153
mock_rpc_channel.send_command.assert_any_call(RoborockCommand.GET_STATUS)
144-
mock_rpc_channel.send_command.assert_any_call(RoborockCommand.GET_MULTI_MAPS_LIST)
145-
mock_rpc_channel.send_command.assert_any_call(RoborockCommand.LOAD_MULTI_MAP, params=[123])
154+
assert mock_mqtt_rpc_channel.send_command.call_count == 2
155+
mock_mqtt_rpc_channel.send_command.assert_any_call(RoborockCommand.GET_MULTI_MAPS_LIST)
156+
mock_mqtt_rpc_channel.send_command.assert_any_call(RoborockCommand.LOAD_MULTI_MAP, params=[123])

0 commit comments

Comments
 (0)