From 27c91b856af83538e9f3a1d09a78bc3689343171 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 22 Mar 2026 17:31:16 -0700 Subject: [PATCH 1/7] feat: Implement direct device trait updates from data protocol messages using `dps` metadata and add corresponding update listeners. This uses the same dps converter patern used by q10, but does not share code explicitly. --- roborock/devices/device.py | 4 +- roborock/devices/device_manager.py | 1 + roborock/devices/rpc/v1_channel.py | 34 +++++-- roborock/devices/traits/v1/__init__.py | 28 +++++- roborock/devices/traits/v1/common.py | 76 +++++++++++++++- roborock/devices/traits/v1/consumeable.py | 18 +++- roborock/devices/traits/v1/status.py | 20 ++++- roborock/protocols/v1_protocol.py | 63 ++++++++++--- tests/devices/rpc/test_v1_channel.py | 4 +- tests/devices/test_v1_device.py | 1 + tests/devices/traits/v1/fixtures.py | 3 +- tests/devices/traits/v1/test_status.py | 76 ++++++++++++++++ .../__snapshots__/test_device_manager.ambr | 8 ++ tests/protocols/test_v1_protocol.py | 88 ++++++++++++++++++- 14 files changed, 395 insertions(+), 29 deletions(-) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 29f1fd28..bf020814 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -199,7 +199,7 @@ async def connect(self) -> None: unsub = await self._channel.subscribe(self._on_message) try: if self.v1_properties is not None: - await self.v1_properties.discover_features() + await self.v1_properties.start() elif self.b01_q10_properties is not None: await self.b01_q10_properties.start() except RoborockException: @@ -216,6 +216,8 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self.v1_properties is not None: + self.v1_properties.close() if self.b01_q10_properties is not None: await self.b01_q10_properties.close() if self._unsub: diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 0be98ea1..701370ae 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -236,6 +236,7 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat channel.rpc_channel, channel.mqtt_rpc_channel, channel.map_rpc_channel, + channel.add_dps_listener, web_api, device_cache=device_cache, map_parser_config=map_parser_config, diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index d1b4ee24..81c3466d 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from typing import Any, TypeVar +from roborock.callbacks import CallbackList from roborock.data import HomeDataDevice, NetworkInfo, RoborockBase, UserData from roborock.devices.cache import DeviceCache from roborock.devices.transport.channel import Channel @@ -30,9 +31,10 @@ V1RpcChannel, create_map_response_decoder, create_security_data, + decode_data_protocol_message, decode_rpc_response, ) -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from roborock.util import RoborockLoggerAdapter @@ -188,6 +190,7 @@ def __init__( self._device_cache = device_cache self._reconnect_task: asyncio.Task[None] | None = None self._last_network_info_refresh: datetime.datetime | None = None + self._dps_listeners = CallbackList[dict[RoborockDataProtocol, Any]]() @property def is_connected(self) -> bool: @@ -305,12 +308,14 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - if not self.is_local_connected: - # We were not able to connect locally, so fallback to MQTT and at least - # establish that connection explicitly. If this fails then raise an - # error and let the caller know we failed to subscribe. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - self._logger.debug("V1Channel connected to device via MQTT") + # Always subscribe to MQTT to receive protocol updates (data points) + # even if we have a local connection. Protocol updates only come via cloud/MQTT. + # Local connection is used for RPC commands, but push notifications come via MQTT. + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + if self.is_local_connected: + self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") + else: + self._logger.debug("V1Channel connected via MQTT only") def unsub() -> None: """Unsubscribe from all messages.""" @@ -328,6 +333,16 @@ def unsub() -> None: self._callback = callback return unsub + def add_dps_listener(self, listener: Callable[[dict[RoborockDataProtocol, Any]], None]) -> Callable[[], None]: + """Add a listener for DPS updates. + + This will attach a listener to the existing subscription, invoking + the listener whenever new DPS values arrive from the subscription. + This will only work if a subscription has already been setup, which is + handled by the device setup. + """ + return self._dps_listeners.add_callback(listener) + async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInfo: """Retrieve networking information for the device. @@ -428,6 +443,11 @@ def _on_mqtt_message(self, message: RoborockMessage) -> None: self._logger.debug("V1Channel received MQTT message: %s", message) if self._callback: self._callback(message) + try: + if datapoints := decode_data_protocol_message(message): + self._dps_listeners(datapoints) + except RoborockException as e: + self._logger.debug("Error decoding data protocol message: %s", e) def _on_local_message(self, message: RoborockMessage) -> None: """Handle incoming local messages.""" diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index a55280f2..ebe32125 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -53,6 +53,7 @@ """ import logging +from collections.abc import Callable from dataclasses import dataclass, field, fields from typing import Any, get_args @@ -60,8 +61,10 @@ from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode from roborock.devices.cache import DeviceCache from roborock.devices.traits import Trait +from roborock.exceptions import RoborockException from roborock.map.map_parser import MapParserConfig -from roborock.protocols.v1_protocol import V1RpcChannel +from roborock.protocols.v1_protocol import V1RpcChannel, decode_data_protocol_message +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage from roborock.web_api import UserWebApiClient from . import ( @@ -176,6 +179,7 @@ def __init__( rpc_channel: V1RpcChannel, mqtt_rpc_channel: V1RpcChannel, map_rpc_channel: V1RpcChannel, + add_dps_listener: Callable[[Callable[[dict[RoborockDataProtocol, Any]], None]], Callable[[], None]], web_api: UserWebApiClient, device_cache: DeviceCache, map_parser_config: MapParserConfig | None = None, @@ -189,6 +193,8 @@ def __init__( self._web_api = web_api self._device_cache = device_cache self._region = region + self._unsub: Callable[[], None] | None = None + self._add_dps_listener = add_dps_listener self.device_features = DeviceFeaturesTrait(product, self._device_cache) self.status = StatusTrait(self.device_features, region=self._region) @@ -227,6 +233,24 @@ def _get_rpc_channel(self, trait: V1TraitMixin) -> V1RpcChannel: else: return self._rpc_channel + async def start(self) -> None: + """Start the properties API and discover features.""" + await self.discover_features() + self._unsub = self._add_dps_listener(self._on_dps_update) + + def close(self) -> None: + if self._unsub: + self._unsub() + + def _on_dps_update(self, dps: dict[RoborockDataProtocol, Any]) -> None: + """Handle incoming messages from the device. + + This will notify all traits of the new values. + """ + _LOGGER.debug("Received message from device: %s", dps) + self.status.update_from_dps(dps) + self.consumables.update_from_dps(dps) + async def discover_features(self) -> None: """Populate any supported traits that were not initialized in __init__.""" _LOGGER.debug("Starting optional trait discovery") @@ -330,6 +354,7 @@ def create( rpc_channel: V1RpcChannel, mqtt_rpc_channel: V1RpcChannel, map_rpc_channel: V1RpcChannel, + add_dps_listener: Callable[[Callable[[dict[RoborockDataProtocol, Any]], None]], Callable[[], None]], web_api: UserWebApiClient, device_cache: DeviceCache, map_parser_config: MapParserConfig | None = None, @@ -343,6 +368,7 @@ def create( rpc_channel, mqtt_rpc_channel, map_rpc_channel, + add_dps_listener, web_api, device_cache, map_parser_config, diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 431cd075..22792d98 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -5,12 +5,15 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import fields -from typing import ClassVar +from typing import Any, ClassVar +from roborock.callbacks import CallbackList from roborock.data import RoborockBase from roborock.exceptions import RoborockParsingException from roborock.protocols.v1_protocol import V1RpcChannel +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand _LOGGER = logging.getLogger(__name__) @@ -182,3 +185,74 @@ def wrapper(*args, **kwargs): cls.map_rpc_channel = True # type: ignore[attr-defined] return wrapper + + +# TODO(allenporter): Merge with roborock.devices.traits.b01.q10.common.TraitUpdateListener +class TraitUpdateListener(ABC): + """Trait update listener. + + This is a base class for traits to support notifying listeners when they + have been updated. Clients may register callbacks to be notified when the + trait has been updated. When the listener callback is invoked, the client + should read the trait's properties to get the updated values. + """ + + def __init__(self, logger: logging.Logger) -> None: + """Initialize the trait update listener.""" + self._update_callbacks: CallbackList[None] = CallbackList(logger=logger) + + def add_update_listener(self, callback: Callable[[], None]) -> Callable[[], None]: + """Register a callback when the trait has been updated. + + Returns a callable to remove the listener. + """ + # We wrap the callback to ignore the value passed to it. + return self._update_callbacks.add_callback(lambda _: callback()) + + def _notify_update(self) -> None: + """Notify all update listeners.""" + self._update_callbacks(None) + + +class DpsDataConverter: + """Utility to handle the transformation and merging of DPS data into models. + + This class pre-calculates the mapping between Data Point IDs and dataclass fields + to optimize repeated updates from device streams. + """ + + def __init__(self, dps_type_map: dict[RoborockDataProtocol, type], dps_field_map: dict[RoborockDataProtocol, str]): + """Initialize the converter for a specific RoborockBase-derived class.""" + self._dps_type_map = dps_type_map + self._dps_field_map = dps_field_map + + @classmethod + def from_dataclass(cls, dataclass_type: type[RoborockBase]): + """Initialize the converter for a specific RoborockBase-derived class.""" + dps_type_map: dict[RoborockDataProtocol, type] = {} + dps_field_map: dict[RoborockDataProtocol, str] = {} + for field_obj in fields(dataclass_type): + if field_obj.metadata and "dps" in field_obj.metadata: + dps_id = field_obj.metadata["dps"] + dps_type_map[dps_id] = field_obj.type + dps_field_map[dps_id] = field_obj.name + return cls(dps_type_map, dps_field_map) + + def update_from_dps(self, target: RoborockBase, decoded_dps: dict[RoborockDataProtocol, Any]) -> bool: + """Convert and merge raw DPS data into the target object. + + Uses the pre-calculated type mapping to ensure values are converted to the + correct Python types before being updated on the target. + + Args: + target: The target object to update. + decoded_dps: The decoded DPS data to convert. + + Returns: + True if any values were updated, False otherwise. + """ + conversions = RoborockBase.convert_dict(self._dps_type_map, decoded_dps) + for dps_id, value in conversions.items(): + field_name = self._dps_field_map[dps_id] + setattr(target, field_name, value) + return bool(conversions) diff --git a/roborock/devices/traits/v1/consumeable.py b/roborock/devices/traits/v1/consumeable.py index 9c72ed68..0d0b9ef3 100644 --- a/roborock/devices/traits/v1/consumeable.py +++ b/roborock/devices/traits/v1/consumeable.py @@ -5,16 +5,21 @@ """ from enum import StrEnum -from typing import Self +from typing import Any, Self from roborock.data import Consumable from roborock.devices.traits.v1 import common +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand +from .common import TraitUpdateListener + __all__ = [ "ConsumableTrait", ] +_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(Consumable) + class ConsumableAttribute(StrEnum): """Enum for consumable attributes.""" @@ -35,7 +40,7 @@ def from_str(cls, value: str) -> Self: raise ValueError(f"Unknown ConsumableAttribute: {value}") -class ConsumableTrait(Consumable, common.V1TraitMixin): +class ConsumableTrait(Consumable, common.V1TraitMixin, TraitUpdateListener): """Trait for managing consumable attributes on Roborock devices. After the first refresh, you can tell what consumables are supported by @@ -49,3 +54,12 @@ async def reset_consumable(self, consumable: ConsumableAttribute) -> None: """Reset a specific consumable attribute on the device.""" await self.rpc_channel.send_command(RoborockCommand.RESET_CONSUMABLE, params=[consumable.value]) await self.refresh() + + def update_from_dps(self, decoded_dps: dict[RoborockDataProtocol, Any]) -> None: + """Update the trait from data protocol push message data. + + This handles unsolicited status updates pushed by the device + via RoborockDataProtocol codes (e.g. STATE=121, BATTERY=122). + """ + if _DPS_CONVERTER.update_from_dps(self, decoded_dps): + self._notify_update() diff --git a/roborock/devices/traits/v1/status.py b/roborock/devices/traits/v1/status.py index 82371c15..84cdcb9b 100644 --- a/roborock/devices/traits/v1/status.py +++ b/roborock/devices/traits/v1/status.py @@ -1,4 +1,6 @@ +import logging from functools import cached_property +from typing import Any from roborock import ( CleanRoutes, @@ -10,13 +12,19 @@ get_water_mode_mapping, get_water_modes, ) +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand from . import common +from .common import TraitUpdateListener from .device_features import DeviceFeaturesTrait +_LOGGER = logging.getLogger(__name__) -class StatusTrait(StatusV2, common.V1TraitMixin): +_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(StatusV2) + + +class StatusTrait(StatusV2, common.V1TraitMixin, TraitUpdateListener): """Trait for managing the status of Roborock devices. The StatusTrait gives you the access to the state of a Roborock vacuum. @@ -47,6 +55,7 @@ class StatusTrait(StatusV2, common.V1TraitMixin): def __init__(self, device_feature_trait: DeviceFeaturesTrait, region: str | None = None) -> None: """Initialize the StatusTrait.""" super().__init__() + TraitUpdateListener.__init__(self, logger=_LOGGER) self._device_features_trait = device_feature_trait self._region = region @@ -91,3 +100,12 @@ def mop_route_name(self) -> str | None: if self.mop_mode is None: return None return self.mop_route_mapping.get(self.mop_mode) + + def update_from_dps(self, decoded_dps: dict[RoborockDataProtocol, Any]) -> None: + """Update the trait from data protocol push message data. + + This handles unsolicited status updates pushed by the device + via RoborockDataProtocol codes (e.g. STATE=121, BATTERY=122). + """ + if _DPS_CONVERTER.update_from_dps(self, decoded_dps): + self._notify_update() diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 355043c5..14144793 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -15,7 +15,7 @@ from roborock.data import RoborockBase, RRiot from roborock.exceptions import RoborockException, RoborockInvalidStatus, RoborockUnsupportedFeature from roborock.protocol import Utils -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from roborock.util import get_next_int, get_timestamp @@ -24,6 +24,7 @@ __all__ = [ "SecurityData", "create_security_data", + "decode_data_protocol_message", "decode_rpc_response", "V1RpcChannel", ] @@ -139,6 +140,28 @@ class ResponseMessage: """The API error message of the response if any.""" +def _decode_dps_message(message: RoborockMessage) -> dict[int, Any] | None: + """Decode a V1 push message containing data protocol updates.""" + if not message.payload: + return None + try: + payload = json.loads(message.payload.decode()) + except (json.JSONDecodeError, TypeError, UnicodeDecodeError) as e: + raise RoborockException(f"Invalid V1 message payload: {e} for {message.payload!r}") from e + + datapoints = payload.get("dps") + if not isinstance(datapoints, dict): + return None + result: dict[int, Any] = {} + for key, value in datapoints.items(): + try: + code = int(key) + except (ValueError, TypeError): + continue + result[code] = value + return result if result else None + + def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: """Decode a V1 RPC_RESPONSE message. @@ -147,19 +170,10 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: response, as long as we can extract the request ID. This is so we can associate an API response with a request even if there was an error. """ - if not message.payload: + if not (datapoints := _decode_dps_message(message)): return ResponseMessage(request_id=message.seq, data={}) - try: - payload = json.loads(message.payload.decode()) - except (json.JSONDecodeError, TypeError, UnicodeDecodeError) as e: - raise RoborockException(f"Invalid V1 message payload: {e} for {message.payload!r}") from e - - _LOGGER.debug("Decoded V1 message payload: %s", payload) - datapoints = payload.get("dps", {}) - if not isinstance(datapoints, dict): - raise RoborockException(f"Invalid V1 message format: 'dps' should be a dictionary for {message.payload!r}") - if not (data_point := datapoints.get(str(RoborockMessageProtocol.RPC_RESPONSE))): + if not (data_point := datapoints.get(RoborockMessageProtocol.RPC_RESPONSE)): raise RoborockException( f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point" ) @@ -206,6 +220,31 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: return ResponseMessage(request_id=request_id, data=result, api_error=api_error) +def decode_data_protocol_message(message: RoborockMessage) -> dict[RoborockDataProtocol, Any] | None: + """Decode a V1 push message containing data protocol updates. + + V1 devices push unsolicited status updates containing data points keyed + by RoborockDataProtocol codes (e.g., 121=STATE, 122=BATTERY). This function + extracts those data points from the message payload. + + Returns a dict mapping RoborockDataProtocol to values, or None if the + message does not contain any recognized data protocol updates. + """ + if not (datapoints := _decode_dps_message(message)): + return None + + result: dict[RoborockDataProtocol, Any] = {} + for code, value in datapoints.items(): + try: + protocol = RoborockDataProtocol(code) + except ValueError: + _LOGGER.debug("Ignoring unknown V1 data protocol code: %s", code) + continue + result[protocol] = value + + return result if result else None + + @dataclass class MapResponse: """Data structure for the V1 Map response.""" diff --git a/tests/devices/rpc/test_v1_channel.py b/tests/devices/rpc/test_v1_channel.py index 293eb260..fac52666 100644 --- a/tests/devices/rpc/test_v1_channel.py +++ b/tests/devices/rpc/test_v1_channel.py @@ -250,8 +250,8 @@ async def test_v1_channel_subscribe_local_success( mock_local_session.assert_called_once_with(TEST_HOST) mock_local_channel.connect.assert_called_once() - # Verify local connection established and not mqtt - assert not mock_mqtt_channel.subscribers + # Verify mqtt is also established + assert mock_mqtt_channel.subscribers assert mock_local_channel.subscribers # Verify properties diff --git a/tests/devices/test_v1_device.py b/tests/devices/test_v1_device.py index 558d838c..8afc62cd 100644 --- a/tests/devices/test_v1_device.py +++ b/tests/devices/test_v1_device.py @@ -62,6 +62,7 @@ def device_fixture(channel: AsyncMock, rpc_channel: AsyncMock, mqtt_rpc_channel: rpc_channel, mqtt_rpc_channel, AsyncMock(), + Mock(), AsyncMock(), device_cache=DeviceCache(HOME_DATA.devices[0].duid, NoCache()), region=USER_DATA.region, diff --git a/tests/devices/traits/v1/fixtures.py b/tests/devices/traits/v1/fixtures.py index 08397493..bf42d151 100644 --- a/tests/devices/traits/v1/fixtures.py +++ b/tests/devices/traits/v1/fixtures.py @@ -1,7 +1,7 @@ """Fixtures for V1 trait tests.""" from copy import deepcopy -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pytest @@ -94,6 +94,7 @@ def device_fixture( mock_rpc_channel, mock_mqtt_rpc_channel, mock_map_rpc_channel, + Mock(), web_api_client, device_cache=device_cache, region=USER_DATA.region, diff --git a/tests/devices/traits/v1/test_status.py b/tests/devices/traits/v1/test_status.py index b0f32ce4..1878d2fd 100644 --- a/tests/devices/traits/v1/test_status.py +++ b/tests/devices/traits/v1/test_status.py @@ -1,5 +1,6 @@ """Tests for the StatusTrait class.""" +import asyncio from typing import cast from unittest.mock import AsyncMock @@ -14,6 +15,7 @@ from roborock.devices.traits.v1.device_features import DeviceFeaturesTrait from roborock.devices.traits.v1.status import StatusTrait from roborock.exceptions import RoborockException, RoborockParsingException +from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand from tests import mock_data from tests.mock_data import STATUS @@ -122,3 +124,77 @@ def test_water_slide_mode_mapping() -> None: assert status_trait.water_mode_name == "low" status_trait.water_box_mode = 200 assert status_trait.water_mode_name == "off" + + +def test_update_from_dps(status_trait: StatusTrait) -> None: + """Test updating status from data protocol push message.""" + assert status_trait.battery is None + assert status_trait.state is None + + status_trait.update_from_dps( + { + RoborockDataProtocol.STATE: 5, + RoborockDataProtocol.BATTERY: 85, + RoborockDataProtocol.FAN_POWER: 102, + } + ) + + assert status_trait.state == 5 + assert status_trait.battery == 85 + assert status_trait.fan_power == 102 + + +def test_update_from_dps_partial(status_trait: StatusTrait) -> None: + """Test that partial updates only modify the specified fields.""" + status_trait.battery = 100 + status_trait.state = RoborockStateCode.charging + + status_trait.update_from_dps( + { + RoborockDataProtocol.BATTERY: 90, + } + ) + + assert status_trait.battery == 90 + assert status_trait.state == RoborockStateCode.charging # Unchanged + + +def test_update_listener(status_trait: StatusTrait) -> None: + """Test that update listeners receive notifications.""" + event = asyncio.Event() + unsubscribe = status_trait.add_update_listener(event.set) + + status_trait.update_from_dps( + { + RoborockDataProtocol.BATTERY: 88, + } + ) + + assert event.is_set() + event.clear() + + unsubscribe() + + status_trait.update_from_dps( + { + RoborockDataProtocol.BATTERY: 87, + } + ) + + assert not event.is_set() + + +def test_update_listener_ignores_unrelated(status_trait: StatusTrait) -> None: + """Test that update listeners are not notified for unrecognized data points.""" + event = asyncio.Event() + unsubscribe = status_trait.add_update_listener(event.set) + + # TASK_COMPLETE is not annotated with dps metadata on StatusV2 + status_trait.update_from_dps( + { + RoborockDataProtocol.TASK_COMPLETE: 1, + } + ) + + assert not event.is_set() + unsubscribe() diff --git a/tests/e2e/__snapshots__/test_device_manager.ambr b/tests/e2e/__snapshots__/test_device_manager.ambr index 17a045d4..0e8290ee 100644 --- a/tests/e2e/__snapshots__/test_device_manager.ambr +++ b/tests/e2e/__snapshots__/test_device_manager.ambr @@ -539,6 +539,14 @@ 00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h| 00000010 a6 a2 2b 00 01 00 10 6d b9 48 37 ed 43 59 7a 90 |..+....m.H7.CYz.| 00000020 ff 43 2f 0a 8f 81 44 e7 b6 b3 85 |.C/...D....| + [mqtt >] + 00000000 10 29 00 04 4d 51 54 54 05 c2 00 2d 00 00 00 00 |.)..MQTT...-....| + 00000010 08 31 39 36 34 38 66 39 34 00 10 32 33 34 36 37 |.19648f94..23467| + 00000020 38 65 61 38 35 34 66 31 39 39 65 |8ea854f199e| + [mqtt >] + 00000000 82 24 00 01 00 00 1e 72 72 2f 6d 2f 6f 2f 75 73 |.$.....rr/m/o/us| + 00000010 65 72 31 32 33 2f 31 39 36 34 38 66 39 34 2f 61 |er123/19648f94/a| + 00000020 62 63 31 32 33 00 |bc123.| [local >] 00000000 00 00 00 77 31 2e 30 00 00 23 8e 00 00 23 8f 68 |...w1.0..#...#.h| 00000010 a6 a2 2e 00 04 00 60 a9 a0 ac af 22 80 bb 11 b7 |......`...."....| diff --git a/tests/protocols/test_v1_protocol.py b/tests/protocols/test_v1_protocol.py index 1ec5026e..b5454057 100644 --- a/tests/protocols/test_v1_protocol.py +++ b/tests/protocols/test_v1_protocol.py @@ -17,9 +17,10 @@ RequestMessage, SecurityData, create_map_response_decoder, + decode_data_protocol_message, decode_rpc_response, ) -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from tests import mock_data @@ -309,3 +310,88 @@ def test_invalid_unicode() -> None: ) with pytest.raises(RoborockException, match="Invalid V1 message payload"): decode_rpc_response(message) + + +def test_decode_data_protocol_message() -> None: + """Test decoding a V1 push message with data protocol updates.""" + payload = json.dumps({"t": 1652547161, "dps": {"121": 8, "122": 95}}).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + result = decode_data_protocol_message(message) + assert result is not None + assert result[RoborockDataProtocol.STATE] == 8 + assert result[RoborockDataProtocol.BATTERY] == 95 + + +def test_decode_data_protocol_message_all_status_fields() -> None: + """Test decoding a push message with all known status data protocol fields.""" + payload = json.dumps( + { + "t": 1652547161, + "dps": {"120": 0, "121": 5, "122": 100, "123": 102, "124": 204, "133": 1}, + } + ).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + result = decode_data_protocol_message(message) + assert result is not None + assert result[RoborockDataProtocol.ERROR_CODE] == 0 + assert result[RoborockDataProtocol.STATE] == 5 + assert result[RoborockDataProtocol.BATTERY] == 100 + assert result[RoborockDataProtocol.FAN_POWER] == 102 + assert result[RoborockDataProtocol.WATER_BOX_MODE] == 204 + assert result[RoborockDataProtocol.CHARGE_STATUS] == 1 + + +def test_decode_data_protocol_message_unknown_codes() -> None: + """Test that unknown data protocol codes are ignored.""" + payload = json.dumps({"t": 1652547161, "dps": {"121": 8, "999": 42}}).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + result = decode_data_protocol_message(message) + assert result is not None + assert len(result) == 1 + assert result[RoborockDataProtocol.STATE] == 8 + + +def test_decode_data_protocol_message_empty_payload() -> None: + """Test decoding with empty payload returns None.""" + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=None, + ) + assert decode_data_protocol_message(message) is None + + +def test_decode_data_protocol_message_rpc_response() -> None: + """Test that an RPC response (code 102) produces None since the value is not a data protocol.""" + # This contains an RPC response (102) which has a JSON string value, not a data protocol code + payload = json.dumps( + { + "t": 1652547161, + "dps": {"102": '{"id":20001,"result":[{"state":8}]}'}, + } + ).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + # Code 102 is not in RoborockDataProtocol enum, so it should be ignored. + # The result should be None (no recognized data protocol codes). + assert decode_data_protocol_message(message) is None + + +def test_decode_data_protocol_message_no_dps() -> None: + """Test decoding message without dps returns None.""" + payload = json.dumps({"t": 1652547161}).encode() + message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_RESPONSE, + payload=payload, + ) + assert decode_data_protocol_message(message) is None From 46a66607dd22e60d8d3dbd910893f13f93e9a78e Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 22 Mar 2026 18:08:15 -0700 Subject: [PATCH 2/7] chore: Apply co-pilot feedback --- roborock/devices/rpc/v1_channel.py | 15 ++++++---- roborock/devices/traits/v1/__init__.py | 3 ++ roborock/devices/traits/v1/consumeable.py | 8 ++++++ roborock/protocols/v1_protocol.py | 7 ++++- tests/devices/rpc/test_v1_channel.py | 34 ++++++++++++++++++++++- tests/fixtures/channel_fixtures.py | 8 ++++++ 6 files changed, 67 insertions(+), 8 deletions(-) diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index 81c3466d..0e99e2fe 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -308,14 +308,17 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - # Always subscribe to MQTT to receive protocol updates (data points) + # Always attempt to subscribe to MQTT to receive protocol updates (data points) # even if we have a local connection. Protocol updates only come via cloud/MQTT. # Local connection is used for RPC commands, but push notifications come via MQTT. - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - if self.is_local_connected: - self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)") - else: - self._logger.debug("V1Channel connected via MQTT only") + try: + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + except RoborockException as err: + if not self.is_local_connected: + # Propagate error if both local and MQTT failed + self._logger.debug("MQTT connection also failed: %s", err) + raise + self._logger.debug("MQTT subscription failed, continuing with local-only connection: %s", err) def unsub() -> None: """Unsubscribe from all messages.""" diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index ebe32125..b430bfa9 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -235,12 +235,15 @@ def _get_rpc_channel(self, trait: V1TraitMixin) -> V1RpcChannel: async def start(self) -> None: """Start the properties API and discover features.""" + if self._unsub: + return await self.discover_features() self._unsub = self._add_dps_listener(self._on_dps_update) def close(self) -> None: if self._unsub: self._unsub() + self._unsub = None def _on_dps_update(self, dps: dict[RoborockDataProtocol, Any]) -> None: """Handle incoming messages from the device. diff --git a/roborock/devices/traits/v1/consumeable.py b/roborock/devices/traits/v1/consumeable.py index 0d0b9ef3..5e213716 100644 --- a/roborock/devices/traits/v1/consumeable.py +++ b/roborock/devices/traits/v1/consumeable.py @@ -4,6 +4,7 @@ periodically, such as filters, brushes, etc. """ +import logging from enum import StrEnum from typing import Any, Self @@ -18,6 +19,8 @@ "ConsumableTrait", ] +_LOGGER = logging.getLogger(__name__) + _DPS_CONVERTER = common.DpsDataConverter.from_dataclass(Consumable) @@ -50,6 +53,11 @@ class ConsumableTrait(Consumable, common.V1TraitMixin, TraitUpdateListener): command = RoborockCommand.GET_CONSUMABLE converter = common.DefaultConverter(Consumable) + def __init__(self) -> None: + """Initialize the consumable trait.""" + super().__init__() + TraitUpdateListener.__init__(self, logger=_LOGGER) + async def reset_consumable(self, consumable: ConsumableAttribute) -> None: """Reset a specific consumable attribute on the device.""" await self.rpc_channel.send_command(RoborockCommand.RESET_CONSUMABLE, params=[consumable.value]) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 14144793..8f039bb0 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -170,9 +170,14 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: response, as long as we can extract the request ID. This is so we can associate an API response with a request even if there was an error. """ - if not (datapoints := _decode_dps_message(message)): + if not message.payload: return ResponseMessage(request_id=message.seq, data={}) + if (datapoints := _decode_dps_message(message)) is None: + raise RoborockException( + f"Invalid V1 message format: missing or invalid 'dps' in payload for {message.payload!r}" + ) + if not (data_point := datapoints.get(RoborockMessageProtocol.RPC_RESPONSE)): raise RoborockException( f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point" diff --git a/tests/devices/rpc/test_v1_channel.py b/tests/devices/rpc/test_v1_channel.py index fac52666..015ce37e 100644 --- a/tests/devices/rpc/test_v1_channel.py +++ b/tests/devices/rpc/test_v1_channel.py @@ -23,7 +23,7 @@ create_mqtt_encoder, ) from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand from tests import mock_data from tests.fixtures.channel_fixtures import FakeChannel @@ -580,3 +580,35 @@ async def test_v1_channel_send_map_command( # Verify the result is the data from our mocked decoder assert result == decompressed_map_data + + +async def test_v1_channel_add_dps_listener( + v1_channel: V1Channel, + mock_mqtt_channel: FakeChannel, +) -> None: + """Test that DPS listeners receive decoded protocol updates from MQTT.""" + mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE) + await v1_channel.subscribe(Mock()) + + # Create a mock listener for DPS updates + dps_listener = Mock() + unsub_dps = v1_channel.add_dps_listener(dps_listener) + + # Simulate an incoming MQTT message with data protocol payload. + dps_payload = json.dumps({"dps": {"121": 5}}).encode() + push_message = RoborockMessage( + protocol=RoborockMessageProtocol.GENERAL_REQUEST, + payload=dps_payload, + ) + mock_mqtt_channel.notify_subscribers(push_message) + + dps_listener.assert_called_once() + called_args = dps_listener.call_args[0][0] + assert called_args[RoborockDataProtocol.STATE] == 5 + + unsub_dps() + + # Verify unsubscribe works + dps_listener.reset_mock() + v1_channel._on_mqtt_message(push_message) + dps_listener.assert_not_called() diff --git a/tests/fixtures/channel_fixtures.py b/tests/fixtures/channel_fixtures.py index 1faae11c..90ace9fa 100644 --- a/tests/fixtures/channel_fixtures.py +++ b/tests/fixtures/channel_fixtures.py @@ -51,3 +51,11 @@ async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Calla """Simulate subscribing to messages.""" self.subscribers.append(callback) return lambda: self.subscribers.remove(callback) + + def notify_subscribers(self, message: RoborockMessage) -> None: + """Notify subscribers of a message. + + This can be used by tests to simulate the channel receiving a message. + """ + for subscriber in list(self.subscribers): + subscriber(message) From 564a479427e86d5efea41bf096284111428d1e43 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 5 Apr 2026 08:45:49 -0700 Subject: [PATCH 3/7] refactor: simplify device feature support checks by using DPS IDs instead of schema codes --- roborock/data/v1/v1_containers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/roborock/data/v1/v1_containers.py b/roborock/data/v1/v1_containers.py index 29fb9c3a..6d8a13f1 100644 --- a/roborock/data/v1/v1_containers.py +++ b/roborock/data/v1/v1_containers.py @@ -98,13 +98,14 @@ class FieldNameBase(StrEnum): class StatusField(FieldNameBase): - """An enum that represents a field in the `Status` class. + """An enum that represents a field in the `StatusV2` class. This is used with `roborock.devices.traits.v1.status.DeviceFeaturesTrait` to understand if a feature is supported by the device using `is_field_supported`. - The enum values are names of fields in the `Status` class. Each field is annotated - with a metadata value to determine if the field is supported by the device. + The enum values are names of fields in the `StatusV2` class. Each field is + annotated with `dps` metadata to map the field to a `RoborockDataProtocol` + value used to check support against the product schema. """ STATE = "state" @@ -629,8 +630,9 @@ class ConsumableField(FieldNameBase): This is used with `roborock.devices.traits.v1.status.DeviceFeaturesTrait` to understand if a feature is supported by the device using `is_field_supported`. - The enum values are names of fields in the `Consumable` class. Each field is annotated - with a metadata value to determine if the field is supported by the device. + The enum values are names of fields in the `Consumable` class. Each field is + annotated with `dps` metadata to map the field to a `RoborockDataProtocol` + value used to check support against the product schema. """ MAIN_BRUSH_WORK_TIME = "main_brush_work_time" From 29b41e64b57cc946c74b430b1be64558f46c910c Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 22 Mar 2026 17:31:16 -0700 Subject: [PATCH 4/7] feat: Implement direct device trait updates from data protocol messages using `dps` metadata and add corresponding update listeners. This uses the same dps converter patern used by q10, but does not share code explicitly. --- roborock/protocols/v1_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 8f039bb0..4cdee6f5 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -170,7 +170,7 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: response, as long as we can extract the request ID. This is so we can associate an API response with a request even if there was an error. """ - if not message.payload: + if not (datapoints := _decode_dps_message(message)): return ResponseMessage(request_id=message.seq, data={}) if (datapoints := _decode_dps_message(message)) is None: From 64f9d2ed95bc8cb2afaee156c245ec57eee50b40 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 22 Mar 2026 18:08:15 -0700 Subject: [PATCH 5/7] chore: Apply co-pilot feedback --- roborock/protocols/v1_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 4cdee6f5..8f039bb0 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -170,7 +170,7 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: response, as long as we can extract the request ID. This is so we can associate an API response with a request even if there was an error. """ - if not (datapoints := _decode_dps_message(message)): + if not message.payload: return ResponseMessage(request_id=message.seq, data={}) if (datapoints := _decode_dps_message(message)) is None: From 8e681a28fbf5ee671749248ed90eccbc5d332049 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 26 Apr 2026 16:17:26 -0700 Subject: [PATCH 6/7] clean up duplicated traits --- roborock/devices/traits/v1/common.py | 76 +---------------------- roborock/devices/traits/v1/consumeable.py | 5 +- roborock/devices/traits/v1/status.py | 4 +- 3 files changed, 5 insertions(+), 80 deletions(-) diff --git a/roborock/devices/traits/v1/common.py b/roborock/devices/traits/v1/common.py index 22792d98..431cd075 100644 --- a/roborock/devices/traits/v1/common.py +++ b/roborock/devices/traits/v1/common.py @@ -5,15 +5,12 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Callable from dataclasses import fields -from typing import Any, ClassVar +from typing import ClassVar -from roborock.callbacks import CallbackList from roborock.data import RoborockBase from roborock.exceptions import RoborockParsingException from roborock.protocols.v1_protocol import V1RpcChannel -from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand _LOGGER = logging.getLogger(__name__) @@ -185,74 +182,3 @@ def wrapper(*args, **kwargs): cls.map_rpc_channel = True # type: ignore[attr-defined] return wrapper - - -# TODO(allenporter): Merge with roborock.devices.traits.b01.q10.common.TraitUpdateListener -class TraitUpdateListener(ABC): - """Trait update listener. - - This is a base class for traits to support notifying listeners when they - have been updated. Clients may register callbacks to be notified when the - trait has been updated. When the listener callback is invoked, the client - should read the trait's properties to get the updated values. - """ - - def __init__(self, logger: logging.Logger) -> None: - """Initialize the trait update listener.""" - self._update_callbacks: CallbackList[None] = CallbackList(logger=logger) - - def add_update_listener(self, callback: Callable[[], None]) -> Callable[[], None]: - """Register a callback when the trait has been updated. - - Returns a callable to remove the listener. - """ - # We wrap the callback to ignore the value passed to it. - return self._update_callbacks.add_callback(lambda _: callback()) - - def _notify_update(self) -> None: - """Notify all update listeners.""" - self._update_callbacks(None) - - -class DpsDataConverter: - """Utility to handle the transformation and merging of DPS data into models. - - This class pre-calculates the mapping between Data Point IDs and dataclass fields - to optimize repeated updates from device streams. - """ - - def __init__(self, dps_type_map: dict[RoborockDataProtocol, type], dps_field_map: dict[RoborockDataProtocol, str]): - """Initialize the converter for a specific RoborockBase-derived class.""" - self._dps_type_map = dps_type_map - self._dps_field_map = dps_field_map - - @classmethod - def from_dataclass(cls, dataclass_type: type[RoborockBase]): - """Initialize the converter for a specific RoborockBase-derived class.""" - dps_type_map: dict[RoborockDataProtocol, type] = {} - dps_field_map: dict[RoborockDataProtocol, str] = {} - for field_obj in fields(dataclass_type): - if field_obj.metadata and "dps" in field_obj.metadata: - dps_id = field_obj.metadata["dps"] - dps_type_map[dps_id] = field_obj.type - dps_field_map[dps_id] = field_obj.name - return cls(dps_type_map, dps_field_map) - - def update_from_dps(self, target: RoborockBase, decoded_dps: dict[RoborockDataProtocol, Any]) -> bool: - """Convert and merge raw DPS data into the target object. - - Uses the pre-calculated type mapping to ensure values are converted to the - correct Python types before being updated on the target. - - Args: - target: The target object to update. - decoded_dps: The decoded DPS data to convert. - - Returns: - True if any values were updated, False otherwise. - """ - conversions = RoborockBase.convert_dict(self._dps_type_map, decoded_dps) - for dps_id, value in conversions.items(): - field_name = self._dps_field_map[dps_id] - setattr(target, field_name, value) - return bool(conversions) diff --git a/roborock/devices/traits/v1/consumeable.py b/roborock/devices/traits/v1/consumeable.py index 5e213716..da0c4c07 100644 --- a/roborock/devices/traits/v1/consumeable.py +++ b/roborock/devices/traits/v1/consumeable.py @@ -9,19 +9,18 @@ from typing import Any, Self from roborock.data import Consumable +from roborock.devices.traits.common import DpsDataConverter, TraitUpdateListener from roborock.devices.traits.v1 import common from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand -from .common import TraitUpdateListener - __all__ = [ "ConsumableTrait", ] _LOGGER = logging.getLogger(__name__) -_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(Consumable) +_DPS_CONVERTER = DpsDataConverter.from_dataclass(Consumable) class ConsumableAttribute(StrEnum): diff --git a/roborock/devices/traits/v1/status.py b/roborock/devices/traits/v1/status.py index 84cdcb9b..c55ae799 100644 --- a/roborock/devices/traits/v1/status.py +++ b/roborock/devices/traits/v1/status.py @@ -12,16 +12,16 @@ get_water_mode_mapping, get_water_modes, ) +from roborock.devices.traits.common import DpsDataConverter, TraitUpdateListener from roborock.roborock_message import RoborockDataProtocol from roborock.roborock_typing import RoborockCommand from . import common -from .common import TraitUpdateListener from .device_features import DeviceFeaturesTrait _LOGGER = logging.getLogger(__name__) -_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(StatusV2) +_DPS_CONVERTER = DpsDataConverter.from_dataclass(StatusV2) class StatusTrait(StatusV2, common.V1TraitMixin, TraitUpdateListener): From 4c2af3e226c47ae6868876da2016c3ca606ea36e Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 26 Apr 2026 16:27:26 -0700 Subject: [PATCH 7/7] chore: address PR feedback --- roborock/devices/rpc/v1_channel.py | 17 +++++++++++------ tests/protocols/test_v1_protocol.py | 23 ++++++++++------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index 0e99e2fe..f602d90d 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -308,9 +308,8 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab loop = asyncio.get_running_loop() self._reconnect_task = loop.create_task(self._background_reconnect()) - # Always attempt to subscribe to MQTT to receive protocol updates (data points) - # even if we have a local connection. Protocol updates only come via cloud/MQTT. - # Local connection is used for RPC commands, but push notifications come via MQTT. + # We maintain an active MQTT subscription even when connected locally to receive + # unsolicited status updates (DPS push messages) directly from the cloud. try: self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) except RoborockException as err: @@ -342,7 +341,7 @@ def add_dps_listener(self, listener: Callable[[dict[RoborockDataProtocol, Any]], This will attach a listener to the existing subscription, invoking the listener whenever new DPS values arrive from the subscription. This will only work if a subscription has already been setup, which is - handled by the device setup. + handled by the device start. """ return self._dps_listeners.add_callback(listener) @@ -447,10 +446,16 @@ def _on_mqtt_message(self, message: RoborockMessage) -> None: if self._callback: self._callback(message) try: - if datapoints := decode_data_protocol_message(message): - self._dps_listeners(datapoints) + datapoints = decode_data_protocol_message(message) except RoborockException as e: self._logger.debug("Error decoding data protocol message: %s", e) + return + + if datapoints: + try: + self._dps_listeners(datapoints) + except Exception: + self._logger.exception("Error in DPS listener callback") def _on_local_message(self, message: RoborockMessage) -> None: """Handle incoming local messages.""" diff --git a/tests/protocols/test_v1_protocol.py b/tests/protocols/test_v1_protocol.py index b5454057..c4bcb35c 100644 --- a/tests/protocols/test_v1_protocol.py +++ b/tests/protocols/test_v1_protocol.py @@ -320,9 +320,7 @@ def test_decode_data_protocol_message() -> None: payload=payload, ) result = decode_data_protocol_message(message) - assert result is not None - assert result[RoborockDataProtocol.STATE] == 8 - assert result[RoborockDataProtocol.BATTERY] == 95 + assert result == {RoborockDataProtocol.STATE: 8, RoborockDataProtocol.BATTERY: 95} def test_decode_data_protocol_message_all_status_fields() -> None: @@ -338,13 +336,14 @@ def test_decode_data_protocol_message_all_status_fields() -> None: payload=payload, ) result = decode_data_protocol_message(message) - assert result is not None - assert result[RoborockDataProtocol.ERROR_CODE] == 0 - assert result[RoborockDataProtocol.STATE] == 5 - assert result[RoborockDataProtocol.BATTERY] == 100 - assert result[RoborockDataProtocol.FAN_POWER] == 102 - assert result[RoborockDataProtocol.WATER_BOX_MODE] == 204 - assert result[RoborockDataProtocol.CHARGE_STATUS] == 1 + assert result == { + RoborockDataProtocol.ERROR_CODE: 0, + RoborockDataProtocol.STATE: 5, + RoborockDataProtocol.BATTERY: 100, + RoborockDataProtocol.FAN_POWER: 102, + RoborockDataProtocol.WATER_BOX_MODE: 204, + RoborockDataProtocol.CHARGE_STATUS: 1, + } def test_decode_data_protocol_message_unknown_codes() -> None: @@ -355,9 +354,7 @@ def test_decode_data_protocol_message_unknown_codes() -> None: payload=payload, ) result = decode_data_protocol_message(message) - assert result is not None - assert len(result) == 1 - assert result[RoborockDataProtocol.STATE] == 8 + assert result == {RoborockDataProtocol.STATE: 8} def test_decode_data_protocol_message_empty_payload() -> None: