Skip to content

Commit 83e6e52

Browse files
committed
fix: handle different query types
1 parent 5193ef4 commit 83e6e52

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

roborock/devices/a01_channel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Thin wrapper around the MQTT channel for Roborock A01 devices."""
22

33
import asyncio
4+
import json
45
import logging
56
from typing import Any, overload
67

@@ -54,6 +55,13 @@ async def send_decoded_command(
5455
await mqtt_channel.publish(roborock_message)
5556
return {}
5657

58+
if isinstance(query_values, str):
59+
try:
60+
query_values = json.loads(query_values)
61+
except ValueError:
62+
_LOGGER.warning("Failed to parse query values: %s", query_values)
63+
return {}
64+
5765
# Merge any results together than contain the requested data. This
5866
# does not use a future since it needs to merge results across responses.
5967
# This could be simplified if we can assume there is a single response.

tests/test_a01_api.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import AsyncGenerator
44
from queue import Queue
55
from typing import Any
6-
from unittest.mock import patch
6+
from unittest.mock import AsyncMock, MagicMock, patch
77

88
import paho.mqtt.client as mqtt
99
import pytest
@@ -307,3 +307,38 @@ async def test_future_timeout(
307307
with patch("roborock.roborock_future.asyncio.timeout", side_effect=asyncio.TimeoutError):
308308
data = await connected_a01_mqtt_client.update_values([RoborockZeoProtocol.STATE])
309309
assert data.get(RoborockZeoProtocol.STATE) is None
310+
311+
312+
async def test_send_decoded_command_handles_stringified_query() -> None:
313+
"""Test that send_decoded_command handles ID_QUERY as a stringified list."""
314+
from roborock.devices.a01_channel import send_decoded_command
315+
from roborock.devices.mqtt_channel import MqttChannel
316+
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessage, RoborockMessageProtocol
317+
318+
channel = MagicMock(spec=MqttChannel)
319+
channel.publish = AsyncMock()
320+
321+
captured_callback = None
322+
323+
async def mock_subscribe(callback):
324+
nonlocal captured_callback
325+
captured_callback = callback
326+
return lambda: None
327+
328+
channel.subscribe = AsyncMock(side_effect=mock_subscribe)
329+
330+
protocol_id = 101
331+
params = {RoborockDyadDataProtocol.ID_QUERY: str([protocol_id])}
332+
333+
task = asyncio.create_task(send_decoded_command(channel, params))
334+
await asyncio.sleep(0)
335+
336+
response_data = {"dps": {str(protocol_id): 123}}
337+
payload = pad(json.dumps(response_data).encode("utf-8"), AES.block_size)
338+
message = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=payload)
339+
340+
if captured_callback:
341+
captured_callback(message)
342+
343+
result = await task
344+
assert result == {protocol_id: 123}

0 commit comments

Comments
 (0)