From 52764c26950db0464584cc41161c41f0c9f155d3 Mon Sep 17 00:00:00 2001 From: Fabian Peter Hammerle Date: Sat, 4 Nov 2023 13:45:21 +0100 Subject: [PATCH] replace paho-mqtt with its async wrapper aiomqtt (to prepare for upgrading PySwitchbot) https://github.com/fphammerle/switchbot-mqtt/issues/103 https://github.com/fphammerle/switchbot-mqtt/issues/180#issuecomment-1741108146 https://github.com/fphammerle/switchbot-mqtt/issues/127#issuecomment-1349244614 --- CHANGELOG.md | 4 + Pipfile | 1 + Pipfile.lock | 19 +- setup.py | 7 +- switchbot_mqtt/__init__.py | 130 ++-- switchbot_mqtt/_actors/__init__.py | 63 +- switchbot_mqtt/_actors/base.py | 125 ++-- switchbot_mqtt/_cli.py | 30 +- tests/test_actor_base.py | 18 +- tests/test_mqtt.py | 613 ++++++++++-------- tests/test_switchbot_button_automator.py | 33 +- tests/test_switchbot_curtain_motor.py | 70 +- .../test_switchbot_curtain_motor_position.py | 171 ++--- 13 files changed, 747 insertions(+), 537 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7adf46c..a023e19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - declare compatibility with `python3.11` +### Changed +- replaced [paho-mqtt](https://github.com/eclipse/paho.mqtt.python) + with its async wrapper [aiomqtt](https://github.com/sbtinstruments/aiomqtt) + ### Removed - compatibility with `python3.7` diff --git a/Pipfile b/Pipfile index 5ed7fb6..289573d 100644 --- a/Pipfile +++ b/Pipfile @@ -11,6 +11,7 @@ black = "*" mypy = "*" pylint = "*" pytest = "*" +pytest-asyncio = "*" pytest-cov = "*" # python3.10 compatibility diff --git a/Pipfile.lock b/Pipfile.lock index e8c43b1..a361aa1 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "cdf039b4e2e188227f3d34852c1dfbb5449901c9b4bc10a8379b46eebe84fb64" + "sha256": "94ad3eac5fb437c0e4a9fe45f316b813bcbc809b0cfc901ba1d885ae2c44fc67" }, "pipfile-spec": 6, "requires": { @@ -16,6 +16,14 @@ ] }, "default": { + "aiomqtt": { + "hashes": [ + "sha256:3925b40b2b95b1905753d53ef3a9162e903cfab35ebe9647ab4d52e45ffb727f", + "sha256:7582f4341f08ef7110dd9ab3a559454dc28ccda1eac502ff8f08a73b238ecede" + ], + "markers": "python_version >= '3.8' and python_version < '4.0'", + "version": "==1.2.1" + }, "bluepy": { "hashes": [ "sha256:2a71edafe103565fb990256ff3624c1653036a837dfc90e1e32b839f83971cec" @@ -275,6 +283,15 @@ "markers": "python_version >= '3.7'", "version": "==7.4.3" }, + "pytest-asyncio": { + "hashes": [ + "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d", + "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b" + ], + "index": "pypi", + "markers": "python_version >= '3.7'", + "version": "==0.21.1" + }, "pytest-cov": { "hashes": [ "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", diff --git a/setup.py b/setup.py index f177b8a..4de377b 100644 --- a/setup.py +++ b/setup.py @@ -73,8 +73,9 @@ ], entry_points={"console_scripts": ["switchbot-mqtt = switchbot_mqtt._cli:_main"]}, # >=3.6 variable type hints, f-strings, typing.Collection & * to force keyword-only arguments - # >=3.7 postponed evaluation of type annotations (PEP563) & dataclass - python_requires=">=3.8", # python<3.8 untested + # >=3.7 postponed evaluation of type annotations (PEP563) & asyncio.run + # >=3.8 unittest.mock.AsyncMock + python_requires=">=3.8", install_requires=[ # >=1.3.0 for btle.BTLEManagementError (could be replaced with BTLEException) # >=0.1.0 for btle.helperExe @@ -83,7 +84,7 @@ # >=0.10.0 for SwitchbotCurtain.{update,get_position} # >=0.9.0 for SwitchbotCurtain.set_position "PySwitchbot>=0.10.0,<0.13", - "paho-mqtt<2", + "aiomqtt<2", ], setup_requires=["setuptools_scm"], tests_require=["pytest"], diff --git a/switchbot_mqtt/__init__.py b/switchbot_mqtt/__init__.py index c5d8b87..363486d 100644 --- a/switchbot_mqtt/__init__.py +++ b/switchbot_mqtt/__init__.py @@ -18,12 +18,12 @@ import logging import socket +import ssl import typing -import paho.mqtt.client +import aiomqtt from switchbot_mqtt._actors import _ButtonAutomator, _CurtainMotor -from switchbot_mqtt._actors.base import _MQTTCallbackUserdata _LOGGER = logging.getLogger(__name__) @@ -34,34 +34,54 @@ _MQTT_LAST_WILL_PAYLOAD = "offline" -def _mqtt_on_connect( - mqtt_client: paho.mqtt.client.Client, - userdata: _MQTTCallbackUserdata, - flags: typing.Dict[str, int], - return_code: int, +async def _listen( + *, + mqtt_client: aiomqtt.Client, + topic_callbacks: typing.Iterable[typing.Tuple[str, typing.Callable]], + mqtt_topic_prefix: str, + retry_count: int, + device_passwords: typing.Dict[str, str], + fetch_device_info: bool, ) -> None: - # pylint: disable=unused-argument; callback - # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L441 - assert return_code == 0, return_code # connection accepted - mqtt_broker_host, mqtt_broker_port, *_ = mqtt_client.socket().getpeername() - # https://www.rfc-editor.org/rfc/rfc5952#section-6 - _LOGGER.debug( - "connected to MQTT broker %s:%d", - f"[{mqtt_broker_host}]" - if mqtt_client.socket().family == socket.AF_INET6 - else mqtt_broker_host, - mqtt_broker_port, - ) - mqtt_client.publish( - topic=userdata.mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC, - payload=_MQTT_BIRTH_PAYLOAD, - retain=True, - ) - _ButtonAutomator.mqtt_subscribe(mqtt_client=mqtt_client, settings=userdata) - _CurtainMotor.mqtt_subscribe(mqtt_client=mqtt_client, settings=userdata) + async with mqtt_client.messages() as messages: + await mqtt_client.publish( + topic=mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC, + payload=_MQTT_BIRTH_PAYLOAD, + retain=True, + ) + async for message in messages: + for topic, callback in topic_callbacks: + if message.topic.matches(topic): + await callback( + mqtt_client=mqtt_client, + message=message, + mqtt_topic_prefix=mqtt_topic_prefix, + retry_count=retry_count, + device_passwords=device_passwords, + fetch_device_info=fetch_device_info, + ) + + +def _log_mqtt_connected(mqtt_client: aiomqtt.Client) -> None: + if _LOGGER.getEffectiveLevel() <= logging.DEBUG: + mqtt_socket = ( + # aiomqtt neither exposes instance of paho.mqtt.client.Client nor socket publicly. + # level condition to avoid accessing protected `mqtt_client._client` in production. + # pylint: disable=protected-access + mqtt_client._client.socket() + ) + (mqtt_broker_host, mqtt_broker_port, *_) = mqtt_socket.getpeername() + # https://github.com/sbtinstruments/aiomqtt/blob/v1.2.1/aiomqtt/client.py#L1089 + _LOGGER.debug( + "connected to MQTT broker %s:%d", + f"[{mqtt_broker_host}]" + if mqtt_socket.family == socket.AF_INET6 + else mqtt_broker_host, + mqtt_broker_port, + ) -def _run( # pylint: disable=too-many-arguments +async def _run( # pylint: disable=too-many-arguments *, mqtt_host: str, mqtt_port: int, @@ -73,33 +93,43 @@ def _run( # pylint: disable=too-many-arguments device_passwords: typing.Dict[str, str], fetch_device_info: bool, ) -> None: - # https://pypi.org/project/paho-mqtt/ - mqtt_client = paho.mqtt.client.Client( - userdata=_MQTTCallbackUserdata( - retry_count=retry_count, - device_passwords=device_passwords, - fetch_device_info=fetch_device_info, - mqtt_topic_prefix=mqtt_topic_prefix, - ) - ) - mqtt_client.on_connect = _mqtt_on_connect _LOGGER.info( "connecting to MQTT broker %s:%d (TLS %s)", mqtt_host, mqtt_port, "disabled" if mqtt_disable_tls else "enabled", ) - if not mqtt_disable_tls: - mqtt_client.tls_set(ca_certs=None) # enable tls trusting default system certs - if mqtt_username: - mqtt_client.username_pw_set(username=mqtt_username, password=mqtt_password) - elif mqtt_password: + if mqtt_password is not None and mqtt_username is None: raise ValueError("Missing MQTT username") - mqtt_client.will_set( - topic=mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC, - payload=_MQTT_LAST_WILL_PAYLOAD, - retain=True, - ) - mqtt_client.connect(host=mqtt_host, port=mqtt_port) - # https://github.com/eclipse/paho.mqtt.python/blob/master/src/paho/mqtt/client.py#L1740 - mqtt_client.loop_forever() + async with aiomqtt.Client( # raises aiomqtt.MqttError + hostname=mqtt_host, + port=mqtt_port, + # > The settings [...] usually represent a higher security level than + # > when calling the SSLContext constructor directly. + # https://web.archive.org/web/20230714183106/https://docs.python.org/3/library/ssl.html + tls_context=None if mqtt_disable_tls else ssl.create_default_context(), + username=None if mqtt_username is None else mqtt_username, + password=None if mqtt_password is None else mqtt_password, + will=aiomqtt.Will( + topic=mqtt_topic_prefix + _MQTT_AVAILABILITY_TOPIC, + payload=_MQTT_LAST_WILL_PAYLOAD, + retain=True, + ), + ) as mqtt_client: + _log_mqtt_connected(mqtt_client=mqtt_client) + topic_callbacks: typing.List[typing.Tuple[str, typing.Callable]] = [] + for actor_class in (_ButtonAutomator, _CurtainMotor): + async for topic, callback in actor_class.mqtt_subscribe( + mqtt_client=mqtt_client, + mqtt_topic_prefix=mqtt_topic_prefix, + fetch_device_info=fetch_device_info, + ): + topic_callbacks.append((topic, callback)) + await _listen( + mqtt_client=mqtt_client, + topic_callbacks=topic_callbacks, + mqtt_topic_prefix=mqtt_topic_prefix, + retry_count=retry_count, + device_passwords=device_passwords, + fetch_device_info=fetch_device_info, + ) diff --git a/switchbot_mqtt/_actors/__init__.py b/switchbot_mqtt/_actors/__init__.py index 653ddbc..a236389 100644 --- a/switchbot_mqtt/_actors/__init__.py +++ b/switchbot_mqtt/_actors/__init__.py @@ -20,10 +20,10 @@ import typing import bluepy.btle -import paho.mqtt.client +import aiomqtt import switchbot -from switchbot_mqtt._actors.base import _MQTTCallbackUserdata, _MQTTControlledActor +from switchbot_mqtt._actors.base import _MQTTControlledActor from switchbot_mqtt._utils import ( _join_mqtt_topic_levels, _MQTTTopicLevel, @@ -69,11 +69,11 @@ def __init__( def _get_device(self) -> switchbot.SwitchbotDevice: return self.__device - def execute_command( + async def execute_command( self, *, mqtt_message_payload: bytes, - mqtt_client: paho.mqtt.client.Client, + mqtt_client: aiomqtt.Client, update_device_info: bool, mqtt_topic_prefix: str, ) -> None: @@ -84,26 +84,30 @@ def execute_command( else: _LOGGER.info("switchbot %s turned on", self._mac_address) # https://www.home-assistant.io/integrations/switch.mqtt/#state_on - self.report_state( + await self.report_state( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix, state=b"ON", ) if update_device_info: - self._update_and_report_device_info(mqtt_client, mqtt_topic_prefix) + await self._update_and_report_device_info( + mqtt_client, mqtt_topic_prefix + ) # https://www.home-assistant.io/integrations/switch.mqtt/#payload_off elif mqtt_message_payload.lower() == b"off": if not self.__device.turn_off(): _LOGGER.error("failed to turn off switchbot %s", self._mac_address) else: _LOGGER.info("switchbot %s turned off", self._mac_address) - self.report_state( + await self.report_state( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix, state=b"OFF", ) if update_device_info: - self._update_and_report_device_info(mqtt_client, mqtt_topic_prefix) + await self._update_and_report_device_info( + mqtt_client, mqtt_topic_prefix + ) else: _LOGGER.warning( "unexpected payload %r (expected 'ON' or 'OFF')", mqtt_message_payload @@ -154,9 +158,9 @@ def __init__( def _get_device(self) -> switchbot.SwitchbotDevice: return self.__device - def _report_position( + async def _report_position( self, - mqtt_client: paho.mqtt.client.Client, # pylint: disable=duplicate-code; similar param list + mqtt_client: aiomqtt.Client, # pylint: disable=duplicate-code; similar param list mqtt_topic_prefix: str, ) -> None: # > position_closed integer (Optional, default: 0) @@ -166,31 +170,31 @@ def _report_position( # SwitchbotCurtain.open() and .close() update the position optimistically, # SwitchbotCurtain.update() fetches the real position via bluetooth. # https://github.com/Danielhiversen/pySwitchbot/blob/0.10.0/switchbot/__init__.py#L202 - self._mqtt_publish( + await self._mqtt_publish( topic_prefix=mqtt_topic_prefix, topic_levels=self._MQTT_POSITION_TOPIC_LEVELS, payload=str(int(self.__device.get_position())).encode(), mqtt_client=mqtt_client, ) - def _update_and_report_device_info( # pylint: disable=arguments-differ; report_position is optional + async def _update_and_report_device_info( # pylint: disable=arguments-differ; report_position is optional self, - mqtt_client: paho.mqtt.client.Client, + mqtt_client: aiomqtt.Client, mqtt_topic_prefix: str, *, report_position: bool = True, ) -> None: - super()._update_and_report_device_info(mqtt_client, mqtt_topic_prefix) + await super()._update_and_report_device_info(mqtt_client, mqtt_topic_prefix) if report_position: - self._report_position( + await self._report_position( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix ) - def execute_command( + async def execute_command( self, *, mqtt_message_payload: bytes, - mqtt_client: paho.mqtt.client.Client, + mqtt_client: aiomqtt.Client, update_device_info: bool, mqtt_topic_prefix: str, ) -> None: @@ -203,7 +207,7 @@ def execute_command( _LOGGER.info("switchbot curtain %s opening", self._mac_address) # > state_opening string (Optional, default: opening) # https://www.home-assistant.io/integrations/cover.mqtt/#state_opening - self.report_state( + await self.report_state( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix, state=b"opening", @@ -215,7 +219,7 @@ def execute_command( else: _LOGGER.info("switchbot curtain %s closing", self._mac_address) # https://www.home-assistant.io/integrations/cover.mqtt/#state_closing - self.report_state( + await self.report_state( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix, state=b"closing", @@ -229,7 +233,7 @@ def execute_command( # no "stopped" state mentioned at # https://www.home-assistant.io/integrations/cover.mqtt/#configuration-variables # https://community.home-assistant.io/t/mqtt-how-to-remove-retained-messages/79029/2 - self.report_state( + await self.report_state( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix, state=b"", @@ -242,18 +246,22 @@ def execute_command( mqtt_message_payload, ) if report_device_info: - self._update_and_report_device_info( + await self._update_and_report_device_info( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix, report_position=report_position, ) @classmethod - def _mqtt_set_position_callback( + async def _mqtt_set_position_callback( cls, - mqtt_client: paho.mqtt.client.Client, - userdata: _MQTTCallbackUserdata, - message: paho.mqtt.client.MQTTMessage, + *, + mqtt_client: aiomqtt.Client, + message: aiomqtt.Message, + mqtt_topic_prefix: str, + retry_count: int, + device_passwords: typing.Dict[str, str], + fetch_device_info: bool, ) -> None: # pylint: disable=unused-argument; callback # https://github.com/eclipse/paho.mqtt.python/blob/v1.6.1/src/paho/mqtt/client.py#L3556 @@ -263,11 +271,14 @@ def _mqtt_set_position_callback( return actor = cls._init_from_topic( topic=message.topic, + mqtt_topic_prefix=mqtt_topic_prefix, expected_topic_levels=cls._MQTT_SET_POSITION_TOPIC_LEVELS, - settings=userdata, + retry_count=retry_count, + device_passwords=device_passwords, ) if not actor: return # warning in _init_from_topic + assert isinstance(message.payload, bytes), message.payload position_percent = int(message.payload.decode(), 10) if position_percent < 0 or position_percent > 100: _LOGGER.warning("invalid position %u%%, ignoring message", position_percent) diff --git a/switchbot_mqtt/_actors/base.py b/switchbot_mqtt/_actors/base.py index ef8f59f..f827682 100644 --- a/switchbot_mqtt/_actors/base.py +++ b/switchbot_mqtt/_actors/base.py @@ -26,14 +26,13 @@ from __future__ import annotations # PEP563 (default in python>=3.10) import abc -import dataclasses import logging import queue import shlex import typing +import aiomqtt import bluepy.btle -import paho.mqtt.client import switchbot from switchbot_mqtt._utils import ( _join_mqtt_topic_levels, @@ -47,14 +46,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclasses.dataclass -class _MQTTCallbackUserdata: - retry_count: int - device_passwords: typing.Dict[str, str] - fetch_device_info: bool - mqtt_topic_prefix: str - - class _MQTTControlledActor(abc.ABC): MQTT_COMMAND_TOPIC_LEVELS: typing.Tuple[_MQTTTopicLevel, ...] = NotImplemented _MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS: typing.Tuple[ @@ -131,37 +122,40 @@ def _update_device_info(self) -> None: ) from exc raise - def _report_battery_level( - self, mqtt_client: paho.mqtt.client.Client, mqtt_topic_prefix: str + async def _report_battery_level( + self, mqtt_client: aiomqtt.Client, mqtt_topic_prefix: str ) -> None: # > battery: Percentage of battery that is left. # https://www.home-assistant.io/integrations/sensor/#device-class - self._mqtt_publish( + await self._mqtt_publish( topic_prefix=mqtt_topic_prefix, topic_levels=self._MQTT_BATTERY_PERCENTAGE_TOPIC_LEVELS, payload=str(self._get_device().get_battery_percent()).encode(), mqtt_client=mqtt_client, ) - def _update_and_report_device_info( - self, mqtt_client: paho.mqtt.client.Client, mqtt_topic_prefix: str + async def _update_and_report_device_info( + self, mqtt_client: aiomqtt.Client, mqtt_topic_prefix: str ) -> None: self._update_device_info() - self._report_battery_level( + await self._report_battery_level( mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix ) @classmethod def _init_from_topic( cls, - topic: str, + *, + topic: aiomqtt.Topic, + mqtt_topic_prefix: str, expected_topic_levels: typing.Collection[_MQTTTopicLevel], - settings: _MQTTCallbackUserdata, + retry_count: int, + device_passwords: typing.Dict[str, str], ) -> typing.Optional[_MQTTControlledActor]: try: mac_address = _parse_mqtt_topic( - topic=topic, - expected_prefix=settings.mqtt_topic_prefix, + topic=topic.value, + expected_prefix=mqtt_topic_prefix, expected_levels=expected_topic_levels, )[_MQTTTopicPlaceholder.MAC_ADDRESS] except ValueError as exc: @@ -172,17 +166,21 @@ def _init_from_topic( return None return cls( mac_address=mac_address, - retry_count=settings.retry_count, - password=settings.device_passwords.get(mac_address, None), + retry_count=retry_count, + password=device_passwords.get(mac_address, None), ) @classmethod - def _mqtt_update_device_info_callback( + async def _mqtt_update_device_info_callback( # pylint: disable=duplicate-code; other callbacks with same params cls, - mqtt_client: paho.mqtt.client.Client, - userdata: _MQTTCallbackUserdata, - message: paho.mqtt.client.MQTTMessage, + *, + mqtt_client: aiomqtt.Client, + message: aiomqtt.Message, + mqtt_topic_prefix: str, + retry_count: int, + device_passwords: typing.Dict[str, str], + fetch_device_info: bool, ) -> None: # pylint: disable=unused-argument; callback # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L469 @@ -192,33 +190,39 @@ def _mqtt_update_device_info_callback( return actor = cls._init_from_topic( topic=message.topic, + mqtt_topic_prefix=mqtt_topic_prefix, expected_topic_levels=cls._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS, - settings=userdata, + retry_count=retry_count, + device_passwords=device_passwords, ) if actor: # pylint: disable=protected-access; own instance - actor._update_and_report_device_info( - mqtt_client=mqtt_client, mqtt_topic_prefix=userdata.mqtt_topic_prefix + await actor._update_and_report_device_info( + mqtt_client=mqtt_client, mqtt_topic_prefix=mqtt_topic_prefix ) @abc.abstractmethod - def execute_command( # pylint: disable=duplicate-code; implementations + async def execute_command( # pylint: disable=duplicate-code; implementations self, *, mqtt_message_payload: bytes, - mqtt_client: paho.mqtt.client.Client, + mqtt_client: aiomqtt.Client, update_device_info: bool, mqtt_topic_prefix: str, ) -> None: raise NotImplementedError() @classmethod - def _mqtt_command_callback( + async def _mqtt_command_callback( # pylint: disable=duplicate-code; other callbacks with same params cls, - mqtt_client: paho.mqtt.client.Client, - userdata: _MQTTCallbackUserdata, - message: paho.mqtt.client.MQTTMessage, + *, + mqtt_client: aiomqtt.Client, + message: aiomqtt.Message, + mqtt_topic_prefix: str, + retry_count: int, + device_passwords: typing.Dict[str, str], + fetch_device_info: bool, ) -> None: # pylint: disable=unused-argument; callback # https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L469 @@ -228,15 +232,18 @@ def _mqtt_command_callback( return actor = cls._init_from_topic( topic=message.topic, + mqtt_topic_prefix=mqtt_topic_prefix, expected_topic_levels=cls.MQTT_COMMAND_TOPIC_LEVELS, - settings=userdata, + retry_count=retry_count, + device_passwords=device_passwords, ) if actor: - actor.execute_command( + assert isinstance(message.payload, bytes), message.payload + await actor.execute_command( mqtt_message_payload=message.payload, mqtt_client=mqtt_client, - update_device_info=userdata.fetch_device_info, - mqtt_topic_prefix=userdata.mqtt_topic_prefix, + update_device_info=fetch_device_info, + mqtt_topic_prefix=mqtt_topic_prefix, ) @classmethod @@ -257,28 +264,32 @@ def _get_mqtt_message_callbacks( return callbacks @classmethod - def mqtt_subscribe( - cls, *, mqtt_client: paho.mqtt.client.Client, settings: _MQTTCallbackUserdata - ) -> None: + async def mqtt_subscribe( + cls, + *, + mqtt_client: aiomqtt.Client, + mqtt_topic_prefix: str, + fetch_device_info: bool, + ) -> typing.AsyncIterator[typing.Tuple[str, typing.Callable]]: for topic_levels, callback in cls._get_mqtt_message_callbacks( - enable_device_info_update_topic=settings.fetch_device_info + enable_device_info_update_topic=fetch_device_info ).items(): topic = _join_mqtt_topic_levels( - topic_prefix=settings.mqtt_topic_prefix, + topic_prefix=mqtt_topic_prefix, topic_levels=topic_levels, mac_address="+", ) _LOGGER.info("subscribing to MQTT topic %r", topic) - mqtt_client.subscribe(topic) - mqtt_client.message_callback_add(sub=topic, callback=callback) + await mqtt_client.subscribe(topic) + yield (topic, callback) - def _mqtt_publish( + async def _mqtt_publish( self, *, topic_prefix: str, topic_levels: typing.Iterable[_MQTTTopicLevel], payload: bytes, - mqtt_client: paho.mqtt.client.Client, + mqtt_client: aiomqtt.Client, ) -> None: topic = _join_mqtt_topic_levels( topic_prefix=topic_prefix, @@ -287,24 +298,22 @@ def _mqtt_publish( ) # https://pypi.org/project/paho-mqtt/#publishing _LOGGER.debug("publishing topic=%s payload=%r", topic, payload) - message_info: paho.mqtt.client.MQTTMessageInfo = mqtt_client.publish( - topic=topic, payload=payload, retain=True - ) - # wait before checking status? - if message_info.rc != paho.mqtt.client.MQTT_ERR_SUCCESS: + try: + await mqtt_client.publish(topic=topic, payload=payload, retain=True) + except aiomqtt.MqttCodeError as exc: _LOGGER.error( - "Failed to publish MQTT message on topic %s (rc=%d)", + "Failed to publish MQTT message on topic %s: aiomqtt.MqttCodeError %s", topic, - message_info.rc, + exc, ) - def report_state( + async def report_state( self, state: bytes, - mqtt_client: paho.mqtt.client.Client, + mqtt_client: aiomqtt.Client, mqtt_topic_prefix: str, ) -> None: - self._mqtt_publish( + await self._mqtt_publish( topic_prefix=mqtt_topic_prefix, topic_levels=self.MQTT_STATE_TOPIC_LEVELS, payload=state, diff --git a/switchbot_mqtt/_cli.py b/switchbot_mqtt/_cli.py index 13a2aa3..1c0b12c 100644 --- a/switchbot_mqtt/_cli.py +++ b/switchbot_mqtt/_cli.py @@ -17,6 +17,7 @@ # along with this program. If not, see . import argparse +import asyncio import json import logging import os @@ -154,17 +155,20 @@ def _main() -> None: device_passwords = json.loads(args.device_password_path.read_text()) else: device_passwords = {} - switchbot_mqtt._run( # pylint: disable=protected-access; internal - mqtt_host=args.mqtt_host, - mqtt_port=mqtt_port, - mqtt_disable_tls=not args.mqtt_enable_tls, - mqtt_username=args.mqtt_username, - mqtt_password=mqtt_password, - mqtt_topic_prefix=args.mqtt_topic_prefix, - retry_count=args.retry_count, - device_passwords=device_passwords, - fetch_device_info=args.fetch_device_info - # > In formal language theory, the empty string, [...], is the unique string of length zero. - # https://en.wikipedia.org/wiki/Empty_string - or bool(os.environ.get("FETCH_DEVICE_INFO")), + asyncio.run( + switchbot_mqtt._run( # pylint: disable=protected-access; internal + mqtt_host=args.mqtt_host, + mqtt_port=mqtt_port, + mqtt_disable_tls=not args.mqtt_enable_tls, + mqtt_username=args.mqtt_username, + mqtt_password=mqtt_password, + mqtt_topic_prefix=args.mqtt_topic_prefix, + retry_count=args.retry_count, + device_passwords=device_passwords, + fetch_device_info=args.fetch_device_info + # > In formal language theory, the empty string, [...], + # > is the unique string of length zero. + # https://en.wikipedia.org/wiki/Empty_string + or bool(os.environ.get("FETCH_DEVICE_INFO")), + ) ) diff --git a/tests/test_actor_base.py b/tests/test_actor_base.py index 49dad35..a370f3f 100644 --- a/tests/test_actor_base.py +++ b/tests/test_actor_base.py @@ -35,7 +35,8 @@ def test_abstract() -> None: ) -def test_execute_command_abstract() -> None: +@pytest.mark.asyncio +async def test_execute_command_abstract() -> None: class _ActorMock(switchbot_mqtt._actors.base._MQTTControlledActor): # pylint: disable=duplicate-code def __init__( @@ -45,7 +46,7 @@ def __init__( mac_address=mac_address, retry_count=retry_count, password=password ) - def execute_command( + async def execute_command( self, *, mqtt_message_payload: bytes, @@ -54,7 +55,7 @@ def execute_command( mqtt_topic_prefix: str, ) -> None: assert 21 - super().execute_command( + await super().execute_command( # type: ignore mqtt_message_payload=mqtt_message_payload, mqtt_client=mqtt_client, update_device_info=update_device_info, @@ -65,9 +66,18 @@ def _get_device(self) -> switchbot.SwitchbotDevice: assert 42 return super()._get_device() + with pytest.raises(TypeError) as exc_info: + # pylint: disable=abstract-class-instantiated + switchbot_mqtt._actors.base._MQTTControlledActor( # type: ignore + mac_address="aa:bb:cc:dd:ee:ff", retry_count=42, password=None + ) + exc_info.match( + r"^Can't instantiate abstract class _MQTTControlledActor" + r" with abstract methods __init__, _get_device, execute_command$" + ) actor = _ActorMock(mac_address="aa:bb:cc:dd:ee:ff", retry_count=42, password=None) with pytest.raises(NotImplementedError): - actor.execute_command( + await actor.execute_command( mqtt_message_payload=b"dummy", mqtt_client="dummy", update_device_info=True, diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py index 1b105be..0dda324 100644 --- a/tests/test_mqtt.py +++ b/tests/test_mqtt.py @@ -18,24 +18,85 @@ import logging import socket +import ssl import typing import unittest.mock import _pytest.logging # pylint: disable=import-private-name; typing import pytest -from paho.mqtt.client import MQTT_ERR_QUEUE_SIZE, MQTT_ERR_SUCCESS, MQTTMessage, Client +import aiomqtt +from paho.mqtt.client import MQTT_ERR_NO_CONN # pylint: disable=import-private-name; internal import switchbot_mqtt import switchbot_mqtt._actors from switchbot_mqtt._actors import _ButtonAutomator, _CurtainMotor -from switchbot_mqtt._actors.base import _MQTTCallbackUserdata, _MQTTControlledActor +from switchbot_mqtt._actors.base import _MQTTControlledActor from switchbot_mqtt._utils import _MQTTTopicLevel, _MQTTTopicPlaceholder # pylint: disable=protected-access # pylint: disable=too-many-arguments; these are tests, no API +@pytest.mark.asyncio +async def test__listen(caplog: _pytest.logging.LogCaptureFixture) -> None: + mqtt_client = unittest.mock.AsyncMock() + messages_mock = unittest.mock.AsyncMock() + + async def _msg_iter() -> typing.AsyncIterator[aiomqtt.Message]: + for topic, payload in [ + ("/foo", b"foo1"), + ("/baz/21/bar", b"42/2"), + ("/baz/bar", b"nope"), + ("/foo", b"foo2"), + ]: + yield aiomqtt.Message( + topic=topic, + payload=payload, + qos=0, + retain=False, + mid=0, + properties=None, + ) + + messages_mock.__aenter__.return_value.__aiter__.side_effect = _msg_iter + mqtt_client.messages = lambda: messages_mock + callback_foo = unittest.mock.AsyncMock() + callback_bar = unittest.mock.AsyncMock() + with caplog.at_level(logging.DEBUG): + await switchbot_mqtt._listen( + mqtt_client=mqtt_client, + topic_callbacks=(("/foo", callback_foo), ("/baz/+/bar", callback_bar)), + mqtt_topic_prefix="whatever/", + retry_count=3, + device_passwords={}, + fetch_device_info=False, + ) + mqtt_client.publish.assert_awaited_once_with( + topic="whatever/switchbot-mqtt/status", payload="online", retain=True + ) + messages_mock.__aenter__.assert_awaited_once_with() + assert callback_foo.await_count == 2 + assert not callback_foo.await_args_list[0].args + kwargs = callback_foo.await_args_list[0].kwargs + assert kwargs["message"].topic.value == "/foo" + assert kwargs["message"].payload == b"foo1" + del kwargs["message"] # type: ignore + assert kwargs == { + "mqtt_client": mqtt_client, + "mqtt_topic_prefix": "whatever/", + "retry_count": 3, + "device_passwords": {}, + "fetch_device_info": False, + } + assert callback_foo.await_args_list[1].kwargs["message"].payload == b"foo2" + assert callback_bar.await_count == 1 + assert ( + callback_bar.await_args_list[0].kwargs["message"].topic.value == "/baz/21/bar" + ) + assert callback_bar.await_args_list[0].kwargs["message"].payload == b"42/2" + + @pytest.mark.parametrize( ("socket_family", "peername", "peername_log"), [ @@ -44,70 +105,37 @@ (socket.AF_INET6, ("::1", 1883, 0, 0), "[::1]:1883"), ], ) -def test__mqtt_on_connect( +def test__log_mqtt_connected( caplog: _pytest.logging.LogCaptureFixture, socket_family: int, # socket.AddressFamily, peername: typing.Tuple[typing.Union[str, int]], peername_log: str, ) -> None: mqtt_client = unittest.mock.MagicMock() - mqtt_client.socket().family = socket_family - mqtt_client.socket().getpeername.return_value = peername + mqtt_client._client.socket().family = socket_family + mqtt_client._client.socket().getpeername.return_value = peername + with caplog.at_level(logging.INFO): + switchbot_mqtt._log_mqtt_connected(mqtt_client) + assert not caplog.records with caplog.at_level(logging.DEBUG): - switchbot_mqtt._mqtt_on_connect( - mqtt_client, - _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="whatever/", - ), - {}, - 0, - ) - mqtt_client.publish.assert_called_once_with( - topic="whatever/switchbot-mqtt/status", payload="online", retain=True + switchbot_mqtt._log_mqtt_connected(mqtt_client) + assert caplog.record_tuples[0] == ( + "switchbot_mqtt", + logging.DEBUG, + f"connected to MQTT broker {peername_log}", ) - assert mqtt_client.subscribe.call_args_list == [ - unittest.mock.call("whatever/switch/switchbot/+/set"), - unittest.mock.call("whatever/cover/switchbot-curtain/+/set"), - unittest.mock.call("whatever/cover/switchbot-curtain/+/position/set-percent"), - ] - assert mqtt_client.message_callback_add.call_count == 3 - assert caplog.record_tuples == [ - ( - "switchbot_mqtt", - logging.DEBUG, - "connected to MQTT broker " + peername_log, - ), - ( - "switchbot_mqtt._actors.base", - logging.INFO, - "subscribing to MQTT topic 'whatever/switch/switchbot/+/set'", - ), - ( - "switchbot_mqtt._actors.base", - logging.INFO, - "subscribing to MQTT topic 'whatever/cover/switchbot-curtain/+/set'", - ), - ( - "switchbot_mqtt._actors.base", - logging.INFO, - "subscribing to MQTT topic " - "'whatever/cover/switchbot-curtain/+/position/set-percent'", - ), - ] +@pytest.mark.asyncio() @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"]) -@pytest.mark.parametrize("mqtt_port", [1833]) +@pytest.mark.parametrize("mqtt_port", [1234]) @pytest.mark.parametrize("retry_count", [3, 21]) @pytest.mark.parametrize( "device_passwords", [{}, {"11:22:33:44:55:66": "password", "aa:bb:cc:dd:ee:ff": "secret"}], ) @pytest.mark.parametrize("fetch_device_info", [True, False]) -def test__run( +async def test__run( caplog: _pytest.logging.LogCaptureFixture, mqtt_host: str, mqtt_port: int, @@ -115,96 +143,129 @@ def test__run( device_passwords: typing.Dict[str, str], fetch_device_info: bool, ) -> None: - with unittest.mock.patch( - "paho.mqtt.client.Client" - ) as mqtt_client_mock, caplog.at_level(logging.DEBUG): - switchbot_mqtt._run( + with unittest.mock.patch("aiomqtt.Client") as mqtt_client_mock, unittest.mock.patch( + "switchbot_mqtt._log_mqtt_connected" + ) as log_connected_mock, unittest.mock.patch( + "switchbot_mqtt._listen" + ) as listen_mock, caplog.at_level( + logging.DEBUG + ): + await switchbot_mqtt._run( mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_disable_tls=False, mqtt_username=None, mqtt_password=None, - mqtt_topic_prefix="homeassistant/", + mqtt_topic_prefix="home/", retry_count=retry_count, device_passwords=device_passwords, fetch_device_info=fetch_device_info, ) mqtt_client_mock.assert_called_once() - assert not mqtt_client_mock.call_args[0] - assert set(mqtt_client_mock.call_args[1].keys()) == {"userdata"} - userdata = mqtt_client_mock.call_args[1]["userdata"] - assert userdata == _MQTTCallbackUserdata( - retry_count=retry_count, - device_passwords=device_passwords, - fetch_device_info=fetch_device_info, - mqtt_topic_prefix="homeassistant/", + assert not mqtt_client_mock.call_args.args + init_kwargs = mqtt_client_mock.call_args.kwargs + assert isinstance(init_kwargs.pop("tls_context"), ssl.SSLContext) + assert init_kwargs.pop("will") == aiomqtt.Will( + topic="home/switchbot-mqtt/status", + payload="offline", + qos=0, + retain=True, + properties=None, ) - assert not mqtt_client_mock().username_pw_set.called - mqtt_client_mock().tls_set.assert_called_once_with(ca_certs=None) - mqtt_client_mock().will_set.assert_called_once_with( - topic="homeassistant/switchbot-mqtt/status", payload="offline", retain=True + assert init_kwargs == { + "hostname": mqtt_host, + "port": mqtt_port, + "username": None, + "password": None, + } + log_connected_mock.assert_called_once() + subscribe_mock = mqtt_client_mock().__aenter__.return_value.subscribe + assert subscribe_mock.await_count == (5 if fetch_device_info else 3) + subscribe_mock.assert_has_awaits( + ( + unittest.mock.call(topic) + for topic in [ + "home/switch/switchbot/+/set", + "home/cover/switchbot-curtain/+/set", + "home/cover/switchbot-curtain/+/position/set-percent", + ] + ), + any_order=True, ) - mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port) - mqtt_client_mock().socket().getpeername.return_value = (mqtt_host, mqtt_port) - with caplog.at_level(logging.DEBUG): - mqtt_client_mock().on_connect(mqtt_client_mock(), userdata, {}, 0) - subscribe_mock = mqtt_client_mock().subscribe - assert subscribe_mock.call_count == (5 if fetch_device_info else 3) - for topic in [ - "homeassistant/switch/switchbot/+/set", - "homeassistant/cover/switchbot-curtain/+/set", - "homeassistant/cover/switchbot-curtain/+/position/set-percent", - ]: - assert unittest.mock.call(topic) in subscribe_mock.call_args_list - for topic in [ - "homeassistant/switch/switchbot/+/request-device-info", - "homeassistant/cover/switchbot-curtain/+/request-device-info", - ]: + if fetch_device_info: + subscribe_mock.assert_has_awaits( + ( + unittest.mock.call("home/switch/switchbot/+/request-device-info"), + unittest.mock.call( + "home/cover/switchbot-curtain/+/request-device-info" + ), + ), + any_order=True, + ) + listen_mock.assert_awaited_once() + assert listen_mock.await_args is not None # for mypy + assert not listen_mock.await_args.args + listen_kwargs = listen_mock.await_args.kwargs + assert ( + listen_kwargs.pop("mqtt_client") # type: ignore + == mqtt_client_mock().__aenter__.return_value + ) + topic_callbacks = listen_kwargs.pop("topic_callbacks") # type: ignore + assert len(topic_callbacks) == (5 if fetch_device_info else 3) + assert ( + "home/switch/switchbot/+/set", + switchbot_mqtt._actors._ButtonAutomator._mqtt_command_callback, + ) in topic_callbacks + assert ( + "home/cover/switchbot-curtain/+/set", + switchbot_mqtt._actors._CurtainMotor._mqtt_command_callback, + ) in topic_callbacks + assert ( + "home/cover/switchbot-curtain/+/position/set-percent", + switchbot_mqtt._actors._CurtainMotor._mqtt_set_position_callback, + ) in topic_callbacks + if fetch_device_info: + assert ( + "home/switch/switchbot/+/request-device-info", + switchbot_mqtt._actors._ButtonAutomator._mqtt_update_device_info_callback, + ) in topic_callbacks assert ( - unittest.mock.call(topic) in subscribe_mock.call_args_list - ) == fetch_device_info - callbacks = { - c[1]["sub"]: c[1]["callback"] - for c in mqtt_client_mock().message_callback_add.call_args_list + "home/cover/switchbot-curtain/+/request-device-info", + switchbot_mqtt._actors._CurtainMotor._mqtt_update_device_info_callback, + ) in topic_callbacks + assert listen_kwargs == { + "device_passwords": device_passwords, + "fetch_device_info": fetch_device_info, + "mqtt_topic_prefix": "home/", + "retry_count": retry_count, } - assert ( # pylint: disable=comparison-with-callable; intended - callbacks["homeassistant/cover/switchbot-curtain/+/position/set-percent"] - == _CurtainMotor._mqtt_set_position_callback + assert caplog.record_tuples[0] == ( + "switchbot_mqtt", + logging.INFO, + f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)", ) - mqtt_client_mock().loop_forever.assert_called_once_with() - assert caplog.record_tuples[:2] == [ - ( - "switchbot_mqtt", - logging.INFO, - f"connecting to MQTT broker {mqtt_host}:{mqtt_port} (TLS enabled)", - ), - ( - "switchbot_mqtt", - logging.DEBUG, - f"connected to MQTT broker {mqtt_host}:{mqtt_port}", - ), - ] - assert len(caplog.record_tuples) == (7 if fetch_device_info else 5) + assert len(caplog.record_tuples) == (5 if fetch_device_info else 3) + 1 assert ( "switchbot_mqtt._actors.base", logging.INFO, - "subscribing to MQTT topic 'homeassistant/switch/switchbot/+/set'", + "subscribing to MQTT topic 'home/switch/switchbot/+/set'", ) in caplog.record_tuples assert ( "switchbot_mqtt._actors.base", logging.INFO, - "subscribing to MQTT topic 'homeassistant/cover/switchbot-curtain/+/set'", + "subscribing to MQTT topic 'home/cover/switchbot-curtain/+/set'", ) in caplog.record_tuples +@pytest.mark.asyncio @pytest.mark.parametrize("mqtt_disable_tls", [True, False]) -def test__run_tls( +async def test__run_tls( caplog: _pytest.logging.LogCaptureFixture, mqtt_disable_tls: bool ) -> None: - with unittest.mock.patch( - "paho.mqtt.client.Client" - ) as mqtt_client_mock, caplog.at_level(logging.INFO): - switchbot_mqtt._run( + with unittest.mock.patch("aiomqtt.Client") as mqtt_client_mock, unittest.mock.patch( + "switchbot_mqtt._listen" + ), caplog.at_level(logging.INFO): + await switchbot_mqtt._run( mqtt_host="mqtt.local", mqtt_port=1234, mqtt_disable_tls=mqtt_disable_tls, @@ -215,28 +276,32 @@ def test__run_tls( device_passwords={}, fetch_device_info=True, ) + mqtt_client_mock.assert_called_once() + assert not mqtt_client_mock.call_args.args + kwargs = mqtt_client_mock.call_args.kwargs if mqtt_disable_tls: - mqtt_client_mock().tls_set.assert_not_called() - else: - mqtt_client_mock().tls_set.assert_called_once_with(ca_certs=None) - if mqtt_disable_tls: + assert kwargs["tls_context"] is None assert caplog.record_tuples[0][2].endswith(" (TLS disabled)") else: + assert isinstance(kwargs["tls_context"], ssl.SSLContext) assert caplog.record_tuples[0][2].endswith(" (TLS enabled)") +@pytest.mark.asyncio @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"]) @pytest.mark.parametrize("mqtt_port", [1833]) @pytest.mark.parametrize("mqtt_username", ["me"]) @pytest.mark.parametrize("mqtt_password", [None, "secret"]) -def test__run_authentication( +async def test__run_authentication( mqtt_host: str, mqtt_port: int, mqtt_username: str, mqtt_password: typing.Optional[str], ) -> None: - with unittest.mock.patch("paho.mqtt.client.Client") as mqtt_client_mock: - switchbot_mqtt._run( + with unittest.mock.patch("aiomqtt.Client") as mqtt_client_mock, unittest.mock.patch( + "switchbot_mqtt._listen" + ): + await switchbot_mqtt._run( mqtt_host=mqtt_host, mqtt_port=mqtt_port, mqtt_disable_tls=True, @@ -247,38 +312,32 @@ def test__run_authentication( device_passwords={}, fetch_device_info=True, ) - mqtt_client_mock.assert_called_once_with( - userdata=_MQTTCallbackUserdata( - retry_count=7, - device_passwords={}, - fetch_device_info=True, - mqtt_topic_prefix="prfx", - ) - ) - mqtt_client_mock().username_pw_set.assert_called_once_with( - username=mqtt_username, password=mqtt_password - ) + mqtt_client_mock.assert_called_once() + assert not mqtt_client_mock.call_args.args + kwargs = mqtt_client_mock.call_args.kwargs + assert kwargs["username"] == mqtt_username + assert kwargs["password"] == mqtt_password +@pytest.mark.asyncio @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"]) @pytest.mark.parametrize("mqtt_port", [1833]) @pytest.mark.parametrize("mqtt_password", ["secret"]) -def test__run_authentication_missing_username( +async def test__run_authentication_missing_username( mqtt_host: str, mqtt_port: int, mqtt_password: str ) -> None: - with unittest.mock.patch("paho.mqtt.client.Client"): - with pytest.raises(ValueError): - switchbot_mqtt._run( - mqtt_host=mqtt_host, - mqtt_port=mqtt_port, - mqtt_disable_tls=True, - mqtt_username=None, - mqtt_password=mqtt_password, - mqtt_topic_prefix="whatever", - retry_count=3, - device_passwords={}, - fetch_device_info=True, - ) + with pytest.raises(ValueError, match=r"^Missing MQTT username$"): + await switchbot_mqtt._run( + mqtt_host=mqtt_host, + mqtt_port=mqtt_port, + mqtt_disable_tls=True, + mqtt_username=None, + mqtt_password=mqtt_password, + mqtt_topic_prefix="whatever", + retry_count=3, + device_passwords={}, + fetch_device_info=True, + ) def _mock_actor_class( @@ -297,11 +356,11 @@ def __init__( mac_address=mac_address, retry_count=retry_count, password=password ) - def execute_command( + async def execute_command( self, *, mqtt_message_payload: bytes, - mqtt_client: Client, + mqtt_client: aiomqtt.Client, update_device_info: bool, mqtt_topic_prefix: str, ) -> None: @@ -313,32 +372,28 @@ def _get_device(self) -> None: return _ActorMock +@pytest.mark.asyncio @pytest.mark.parametrize( ("topic_levels", "topic", "expected_mac_address"), [ ( switchbot_mqtt._actors._ButtonAutomator._MQTT_UPDATE_DEVICE_INFO_TOPIC_LEVELS, - b"prfx/switch/switchbot/aa:bb:cc:dd:ee:ff/request-device-info", + "prfx/switch/switchbot/aa:bb:cc:dd:ee:ff/request-device-info", "aa:bb:cc:dd:ee:ff", ), ], ) @pytest.mark.parametrize("payload", [b"", b"whatever"]) -def test__mqtt_update_device_info_callback( +async def test__mqtt_update_device_info_callback( caplog: _pytest.logging.LogCaptureFixture, topic_levels: typing.Tuple[_MQTTTopicLevel, ...], - topic: bytes, + topic: str, expected_mac_address: str, payload: bytes, ) -> None: ActorMock = _mock_actor_class(request_info_levels=topic_levels) - message = MQTTMessage(topic=topic) - message.payload = payload - callback_userdata = _MQTTCallbackUserdata( - retry_count=21, # tested in test__mqtt_command_callback - device_passwords={}, - fetch_device_info=True, - mqtt_topic_prefix="prfx/", + message = aiomqtt.Message( + topic=topic, payload=payload, qos=0, retain=False, mid=0, properties=None ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None @@ -347,8 +402,13 @@ def test__mqtt_update_device_info_callback( ) as update_mock, caplog.at_level( logging.DEBUG ): - ActorMock._mqtt_update_device_info_callback( - "client_dummy", callback_userdata, message + await ActorMock._mqtt_update_device_info_callback( + mqtt_client="client_dummy", + message=message, + mqtt_topic_prefix="prfx/", + retry_count=21, # tested in test__mqtt_command_callback + device_passwords={}, + fetch_device_info=True, ) init_mock.assert_called_once_with( mac_address=expected_mac_address, retry_count=21, password=None @@ -360,20 +420,26 @@ def test__mqtt_update_device_info_callback( ( "switchbot_mqtt._actors.base", logging.DEBUG, - f"received topic={topic.decode()} payload={payload!r}", + f"received topic={topic} payload={payload!r}", ) ] -def test__mqtt_update_device_info_callback_ignore_retained( +@pytest.mark.asyncio +async def test__mqtt_update_device_info_callback_ignore_retained( caplog: _pytest.logging.LogCaptureFixture, ) -> None: ActorMock = _mock_actor_class( request_info_levels=(_MQTTTopicPlaceholder.MAC_ADDRESS, "request") ) - message = MQTTMessage(topic=b"aa:bb:cc:dd:ee:ff/request") - message.payload = b"" - message.retain = True + message = aiomqtt.Message( + topic="aa:bb:cc:dd:ee:ff/request", + payload=b"", + qos=0, + retain=True, + mid=0, + properties=None, + ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None ) as init_mock, unittest.mock.patch.object( @@ -381,18 +447,17 @@ def test__mqtt_update_device_info_callback_ignore_retained( ) as execute_command_mock, caplog.at_level( logging.DEBUG ): - ActorMock._mqtt_update_device_info_callback( - "client_dummy", - _MQTTCallbackUserdata( - retry_count=21, - device_passwords={}, - fetch_device_info=True, - mqtt_topic_prefix="ignored", - ), - message, + await ActorMock._mqtt_update_device_info_callback( + mqtt_client="client_dummy", + message=message, + mqtt_topic_prefix="ignored", + retry_count=21, + device_passwords={}, + fetch_device_info=True, ) init_mock.assert_not_called() execute_command_mock.assert_not_called() + execute_command_mock.assert_not_awaited() assert caplog.record_tuples == [ ( "switchbot_mqtt._actors.base", @@ -403,6 +468,7 @@ def test__mqtt_update_device_info_callback_ignore_retained( ] +@pytest.mark.asyncio @pytest.mark.parametrize( ( "topic_prefix", @@ -415,49 +481,49 @@ def test__mqtt_update_device_info_callback_ignore_retained( ( "homeassistant/", _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS, - b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", + "homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"ON", "aa:bb:cc:dd:ee:ff", ), ( "homeassistant/", _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS, - b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", + "homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"OFF", "aa:bb:cc:dd:ee:ff", ), ( "homeassistant/", _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS, - b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", + "homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"on", "aa:bb:cc:dd:ee:ff", ), ( "homeassistant/", _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS, - b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", + "homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"off", "aa:bb:cc:dd:ee:ff", ), ( "prefix-", _ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS, - b"prefix-switch/switchbot/aa:01:23:45:67:89/set", + "prefix-switch/switchbot/aa:01:23:45:67:89/set", b"ON", "aa:01:23:45:67:89", ), ( "", ["switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS], - b"switchbot/aa:01:23:45:67:89", + "switchbot/aa:01:23:45:67:89", b"ON", "aa:01:23:45:67:89", ), ( "homeassistant/", _CurtainMotor.MQTT_COMMAND_TOPIC_LEVELS, - b"homeassistant/cover/switchbot-curtain/aa:01:23:45:67:89/set", + "homeassistant/cover/switchbot-curtain/aa:01:23:45:67:89/set", b"OPEN", "aa:01:23:45:67:89", ), @@ -465,24 +531,19 @@ def test__mqtt_update_device_info_callback_ignore_retained( ) @pytest.mark.parametrize("retry_count", (3, 42)) @pytest.mark.parametrize("fetch_device_info", [True, False]) -def test__mqtt_command_callback( +async def test__mqtt_command_callback( caplog: _pytest.logging.LogCaptureFixture, topic_prefix: str, command_topic_levels: typing.Tuple[_MQTTTopicLevel, ...], - topic: bytes, + topic: str, payload: bytes, expected_mac_address: str, retry_count: int, fetch_device_info: bool, ) -> None: ActorMock = _mock_actor_class(command_topic_levels=command_topic_levels) - message = MQTTMessage(topic=topic) - message.payload = payload - callback_userdata = _MQTTCallbackUserdata( - retry_count=retry_count, - device_passwords={}, - fetch_device_info=fetch_device_info, - mqtt_topic_prefix=topic_prefix, + message = aiomqtt.Message( + topic=topic, payload=payload, qos=0, retain=False, mid=0, properties=None ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None @@ -491,11 +552,18 @@ def test__mqtt_command_callback( ) as execute_command_mock, caplog.at_level( logging.DEBUG ): - ActorMock._mqtt_command_callback("client_dummy", callback_userdata, message) + await ActorMock._mqtt_command_callback( + mqtt_client="client_dummy", + message=message, + retry_count=retry_count, + device_passwords={}, + fetch_device_info=fetch_device_info, + mqtt_topic_prefix=topic_prefix, + ) init_mock.assert_called_once_with( mac_address=expected_mac_address, retry_count=retry_count, password=None ) - execute_command_mock.assert_called_once_with( + execute_command_mock.assert_awaited_once_with( mqtt_client="client_dummy", mqtt_message_payload=payload, update_device_info=fetch_device_info, @@ -505,11 +573,12 @@ def test__mqtt_command_callback( ( "switchbot_mqtt._actors.base", logging.DEBUG, - f"received topic={topic.decode()} payload={payload!r}", + f"received topic={topic} payload={payload!r}", ) ] +@pytest.mark.asyncio @pytest.mark.parametrize( ("mac_address", "expected_password"), [ @@ -518,34 +587,41 @@ def test__mqtt_command_callback( ("11:22:33:dd:ee:ff", "äöü"), ], ) -def test__mqtt_command_callback_password( +async def test__mqtt_command_callback_password( mac_address: str, expected_password: typing.Optional[str] ) -> None: ActorMock = _mock_actor_class( command_topic_levels=("switchbot", _MQTTTopicPlaceholder.MAC_ADDRESS) ) - message = MQTTMessage(topic=b"prefix-switchbot/" + mac_address.encode()) - message.payload = b"whatever" - callback_userdata = _MQTTCallbackUserdata( - retry_count=3, - device_passwords={ - "11:22:33:44:55:77": "test", - "aa:bb:cc:dd:ee:ff": "secret", - "11:22:33:dd:ee:ff": "äöü", - }, - fetch_device_info=True, - mqtt_topic_prefix="prefix-", + message = aiomqtt.Message( + topic="prefix-switchbot/" + mac_address, + payload=b"whatever", + qos=0, + retain=False, + mid=0, + properties=None, ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None ) as init_mock, unittest.mock.patch.object( ActorMock, "execute_command" ) as execute_command_mock: - ActorMock._mqtt_command_callback("client_dummy", callback_userdata, message) + await ActorMock._mqtt_command_callback( + mqtt_client="client_dummy", + message=message, + retry_count=3, + device_passwords={ + "11:22:33:44:55:77": "test", + "aa:bb:cc:dd:ee:ff": "secret", + "11:22:33:dd:ee:ff": "äöü", + }, + fetch_device_info=True, + mqtt_topic_prefix="prefix-", + ) init_mock.assert_called_once_with( mac_address=mac_address, retry_count=3, password=expected_password ) - execute_command_mock.assert_called_once_with( + execute_command_mock.assert_awaited_once_with( mqtt_client="client_dummy", mqtt_message_payload=b"whatever", update_device_info=True, @@ -553,22 +629,24 @@ def test__mqtt_command_callback_password( ) +@pytest.mark.asyncio @pytest.mark.parametrize( ("topic", "payload"), [ - (b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff", b"on"), - (b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/change", b"ON"), - (b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set/suffix", b"ON"), + ("homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff", b"on"), + ("homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/change", b"ON"), + ("homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set/suffix", b"ON"), ], ) -def test__mqtt_command_callback_unexpected_topic( - caplog: _pytest.logging.LogCaptureFixture, topic: bytes, payload: bytes +async def test__mqtt_command_callback_unexpected_topic( + caplog: _pytest.logging.LogCaptureFixture, topic: str, payload: bytes ) -> None: ActorMock = _mock_actor_class( command_topic_levels=_ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS ) - message = MQTTMessage(topic=topic) - message.payload = payload + message = aiomqtt.Message( + topic=topic, payload=payload, qos=0, retain=False, mid=0, properties=None + ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None ) as init_mock, unittest.mock.patch.object( @@ -576,42 +654,43 @@ def test__mqtt_command_callback_unexpected_topic( ) as execute_command_mock, caplog.at_level( logging.DEBUG ): - ActorMock._mqtt_command_callback( - "client_dummy", - _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=True, - mqtt_topic_prefix="homeassistant/", - ), - message, + await ActorMock._mqtt_command_callback( + mqtt_client="client_dummy", + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=True, + mqtt_topic_prefix="homeassistant/", ) init_mock.assert_not_called() execute_command_mock.assert_not_called() + execute_command_mock.assert_not_awaited() assert caplog.record_tuples == [ ( "switchbot_mqtt._actors.base", logging.DEBUG, - f"received topic={topic.decode()} payload={payload!r}", + f"received topic={topic} payload={payload!r}", ), ( "switchbot_mqtt._actors.base", logging.WARNING, - f"unexpected topic {topic.decode()}", + f"unexpected topic {topic}", ), ] +@pytest.mark.asyncio @pytest.mark.parametrize(("mac_address", "payload"), [("aa:01:23:4E:RR:OR", b"ON")]) -def test__mqtt_command_callback_invalid_mac_address( +async def test__mqtt_command_callback_invalid_mac_address( caplog: _pytest.logging.LogCaptureFixture, mac_address: str, payload: bytes ) -> None: ActorMock = _mock_actor_class( command_topic_levels=_ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS ) - topic = f"mqttprefix-switch/switchbot/{mac_address}/set".encode() - message = MQTTMessage(topic=topic) - message.payload = payload + topic = f"mqttprefix-switch/switchbot/{mac_address}/set" + message = aiomqtt.Message( + topic=topic, payload=payload, qos=0, retain=False, mid=0, properties=None + ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None ) as init_mock, unittest.mock.patch.object( @@ -619,15 +698,13 @@ def test__mqtt_command_callback_invalid_mac_address( ) as execute_command_mock, caplog.at_level( logging.DEBUG ): - ActorMock._mqtt_command_callback( - "client_dummy", - _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=True, - mqtt_topic_prefix="mqttprefix-", - ), - message, + await ActorMock._mqtt_command_callback( + mqtt_client="client_dummy", + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=True, + mqtt_topic_prefix="mqttprefix-", ) init_mock.assert_not_called() execute_command_mock.assert_not_called() @@ -635,7 +712,7 @@ def test__mqtt_command_callback_invalid_mac_address( ( "switchbot_mqtt._actors.base", logging.DEBUG, - f"received topic={topic.decode()} payload={payload!r}", + f"received topic={topic} payload={payload!r}", ), ( "switchbot_mqtt._actors.base", @@ -645,19 +722,20 @@ def test__mqtt_command_callback_invalid_mac_address( ] +@pytest.mark.asyncio @pytest.mark.parametrize( ("topic", "payload"), - [(b"homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"ON")], + [("homeassistant/switch/switchbot/aa:bb:cc:dd:ee:ff/set", b"ON")], ) -def test__mqtt_command_callback_ignore_retained( - caplog: _pytest.logging.LogCaptureFixture, topic: bytes, payload: bytes +async def test__mqtt_command_callback_ignore_retained( + caplog: _pytest.logging.LogCaptureFixture, topic: str, payload: bytes ) -> None: ActorMock = _mock_actor_class( command_topic_levels=_ButtonAutomator.MQTT_COMMAND_TOPIC_LEVELS ) - message = MQTTMessage(topic=topic) - message.payload = payload - message.retain = True + message = aiomqtt.Message( + topic=topic, payload=payload, qos=0, retain=True, mid=0, properties=None + ) with unittest.mock.patch.object( ActorMock, "__init__", return_value=None ) as init_mock, unittest.mock.patch.object( @@ -665,28 +743,28 @@ def test__mqtt_command_callback_ignore_retained( ) as execute_command_mock, caplog.at_level( logging.DEBUG ): - ActorMock._mqtt_command_callback( - "client_dummy", - _MQTTCallbackUserdata( - retry_count=4, - device_passwords={}, - fetch_device_info=True, - mqtt_topic_prefix="homeassistant/", - ), - message, + await ActorMock._mqtt_command_callback( + mqtt_client="client_dummy", + message=message, + retry_count=4, + device_passwords={}, + fetch_device_info=True, + mqtt_topic_prefix="homeassistant/", ) init_mock.assert_not_called() execute_command_mock.assert_not_called() + execute_command_mock.assert_not_awaited() assert caplog.record_tuples == [ ( "switchbot_mqtt._actors.base", logging.DEBUG, - f"received topic={topic.decode()} payload={payload!r}", + f"received topic={topic} payload={payload!r}", ), ("switchbot_mqtt._actors.base", logging.INFO, "ignoring retained message"), ] +@pytest.mark.asyncio @pytest.mark.parametrize( ("topic_prefix", "state_topic_levels", "mac_address", "expected_topic"), # https://www.home-assistant.io/docs/mqtt/discovery/#switches @@ -706,15 +784,15 @@ def test__mqtt_command_callback_ignore_retained( ], ) @pytest.mark.parametrize("state", [b"ON", b"CLOSE"]) -@pytest.mark.parametrize("return_code", [MQTT_ERR_SUCCESS, MQTT_ERR_QUEUE_SIZE]) -def test__report_state( +@pytest.mark.parametrize("mqtt_publish_fails", [False, True]) +async def test__report_state( caplog: _pytest.logging.LogCaptureFixture, topic_prefix: str, state_topic_levels: typing.Tuple[_MQTTTopicLevel, ...], mac_address: str, expected_topic: str, state: bytes, - return_code: int, + mqtt_publish_fails: bool, ) -> None: # pylint: disable=too-many-arguments class _ActorMock(_MQTTControlledActor): @@ -727,11 +805,11 @@ def __init__( mac_address=mac_address, retry_count=retry_count, password=password ) - def execute_command( + async def execute_command( self, *, mqtt_message_payload: bytes, - mqtt_client: Client, + mqtt_client: aiomqtt.Client, update_device_info: bool, mqtt_topic_prefix: str, ) -> None: @@ -740,16 +818,18 @@ def execute_command( def _get_device(self) -> None: return None - mqtt_client_mock = unittest.mock.MagicMock() - mqtt_client_mock.publish.return_value.rc = return_code + mqtt_client_mock = unittest.mock.AsyncMock() + if mqtt_publish_fails: + # https://github.com/sbtinstruments/aiomqtt/blob/v1.2.1/aiomqtt/client.py#L678 + mqtt_client_mock.publish.side_effect = aiomqtt.MqttCodeError( + MQTT_ERR_NO_CONN, "Could not publish message" + ) with caplog.at_level(logging.DEBUG): actor = _ActorMock(mac_address=mac_address, retry_count=3, password=None) - actor.report_state( - state=state, - mqtt_client=mqtt_client_mock, - mqtt_topic_prefix=topic_prefix, + await actor.report_state( + state=state, mqtt_client=mqtt_client_mock, mqtt_topic_prefix=topic_prefix ) - mqtt_client_mock.publish.assert_called_once_with( + mqtt_client_mock.publish.assert_awaited_once_with( topic=expected_topic, payload=state, retain=True ) assert caplog.record_tuples[0] == ( @@ -757,13 +837,14 @@ def _get_device(self) -> None: logging.DEBUG, f"publishing topic={expected_topic} payload={state!r}", ) - if return_code == MQTT_ERR_SUCCESS: + if not mqtt_publish_fails: assert not caplog.records[1:] else: assert caplog.record_tuples[1:] == [ ( "switchbot_mqtt._actors.base", logging.ERROR, - f"Failed to publish MQTT message on topic {expected_topic} (rc={return_code})", + f"Failed to publish MQTT message on topic {expected_topic}:" + " aiomqtt.MqttCodeError [code:4] The client is not currently connected.", ) ] diff --git a/tests/test_switchbot_button_automator.py b/tests/test_switchbot_button_automator.py index 28a89ab..e0d1da0 100644 --- a/tests/test_switchbot_button_automator.py +++ b/tests/test_switchbot_button_automator.py @@ -43,27 +43,29 @@ def test_get_mqtt_battery_percentage_topic(prefix: str, mac_address: str) -> Non ) +@pytest.mark.asyncio @pytest.mark.parametrize("topic_prefix", ["homeassistant/", "prefix-", ""]) @pytest.mark.parametrize(("battery_percent", "battery_percent_encoded"), [(42, b"42")]) -def test__update_and_report_device_info( +async def test__update_and_report_device_info( topic_prefix: str, battery_percent: int, battery_percent_encoded: bytes ) -> None: with unittest.mock.patch("switchbot.SwitchbotCurtain.__init__", return_value=None): actor = _ButtonAutomator(mac_address="dummy", retry_count=21, password=None) actor._get_device()._switchbot_device_data = {"data": {"battery": battery_percent}} - mqtt_client_mock = unittest.mock.MagicMock() + mqtt_client_mock = unittest.mock.AsyncMock() with unittest.mock.patch("switchbot.Switchbot.update") as update_mock: - actor._update_and_report_device_info( + await actor._update_and_report_device_info( mqtt_client=mqtt_client_mock, mqtt_topic_prefix=topic_prefix ) update_mock.assert_called_once_with() - mqtt_client_mock.publish.assert_called_once_with( + mqtt_client_mock.publish.assert_awaited_once_with( topic=f"{topic_prefix}switch/switchbot/dummy/battery-percentage", payload=battery_percent_encoded, retain=True, ) +@pytest.mark.asyncio @pytest.mark.parametrize("topic_prefix", ["homeassistant/"]) @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"]) @pytest.mark.parametrize("password", (None, "secret")) @@ -81,7 +83,7 @@ def test__update_and_report_device_info( ) @pytest.mark.parametrize("update_device_info", [True, False]) @pytest.mark.parametrize("command_successful", [True, False]) -def test_execute_command( +async def test_execute_command( caplog: _pytest.logging.LogCaptureFixture, topic_prefix: str, mac_address: str, @@ -98,6 +100,7 @@ def test_execute_command( actor = _ButtonAutomator( mac_address=mac_address, retry_count=retry_count, password=password ) + mqtt_client = unittest.mock.Mock() with unittest.mock.patch.object( actor, "report_state" ) as report_mock, unittest.mock.patch( @@ -105,8 +108,8 @@ def test_execute_command( ) as action_mock, unittest.mock.patch.object( actor, "_update_and_report_device_info" ) as update_device_info_mock: - actor.execute_command( - mqtt_client="dummy", + await actor.execute_command( + mqtt_client=mqtt_client, mqtt_message_payload=message_payload, update_device_info=update_device_info, mqtt_topic_prefix=topic_prefix, @@ -124,7 +127,7 @@ def test_execute_command( ) ] report_mock.assert_called_once_with( - mqtt_client="dummy", + mqtt_client=mqtt_client, mqtt_topic_prefix=topic_prefix, state=message_payload.upper(), ) @@ -141,9 +144,10 @@ def test_execute_command( update_device_info_mock.assert_not_called() +@pytest.mark.asyncio @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"]) @pytest.mark.parametrize("message_payload", [b"EIN", b""]) -def test_execute_command_invalid_payload( +async def test_execute_command_invalid_payload( caplog: _pytest.logging.LogCaptureFixture, mac_address: str, message_payload: bytes ) -> None: with unittest.mock.patch("switchbot.Switchbot") as device_mock, caplog.at_level( @@ -151,8 +155,8 @@ def test_execute_command_invalid_payload( ): actor = _ButtonAutomator(mac_address=mac_address, retry_count=21, password=None) with unittest.mock.patch.object(actor, "report_state") as report_mock: - actor.execute_command( - mqtt_client="dummy", + await actor.execute_command( + mqtt_client=unittest.mock.Mock(), mqtt_message_payload=message_payload, update_device_info=True, mqtt_topic_prefix="dummy", @@ -169,9 +173,10 @@ def test_execute_command_invalid_payload( ] +@pytest.mark.asyncio @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"]) @pytest.mark.parametrize("message_payload", [b"ON", b"OFF"]) -def test_execute_command_bluetooth_error( +async def test_execute_command_bluetooth_error( caplog: _pytest.logging.LogCaptureFixture, mac_address: str, message_payload: bytes ) -> None: """ @@ -186,10 +191,10 @@ def test_execute_command_bluetooth_error( f"Failed to connect to peripheral {mac_address}, addr type: random" ), ), caplog.at_level(logging.ERROR): - _ButtonAutomator( + await _ButtonAutomator( mac_address=mac_address, retry_count=0, password=None ).execute_command( - mqtt_client="dummy", + mqtt_client=unittest.mock.Mock(), mqtt_message_payload=message_payload, update_device_info=True, mqtt_topic_prefix="dummy", diff --git a/tests/test_switchbot_curtain_motor.py b/tests/test_switchbot_curtain_motor.py index dce125d..acdbfcd 100644 --- a/tests/test_switchbot_curtain_motor.py +++ b/tests/test_switchbot_curtain_motor.py @@ -50,6 +50,7 @@ def test_get_mqtt_position_topic(mac_address: str) -> None: ) +@pytest.mark.asyncio @pytest.mark.parametrize( "mac_address", ("aa:bb:cc:dd:ee:ff", "aa:bb:cc:dd:ee:gg"), @@ -57,7 +58,7 @@ def test_get_mqtt_position_topic(mac_address: str) -> None: @pytest.mark.parametrize( ("position", "expected_payload"), [(0, b"0"), (100, b"100"), (42, b"42")] ) -def test__report_position( +async def test__report_position( caplog: _pytest.logging.LogCaptureFixture, mac_address: str, position: int, @@ -78,13 +79,16 @@ def test__report_position( # https://github.com/Danielhiversen/pySwitchbot/blob/0.10.0/switchbot/__init__.py#L150 reverse_mode=True, ) + mqtt_client = unittest.mock.Mock() with unittest.mock.patch.object( actor, "_mqtt_publish" ) as publish_mock, unittest.mock.patch( "switchbot.SwitchbotCurtain.get_position", return_value=position ): - actor._report_position(mqtt_client="dummy", mqtt_topic_prefix="topic-prefix") - publish_mock.assert_called_once_with( + await actor._report_position( + mqtt_client=mqtt_client, mqtt_topic_prefix="topic-prefix" + ) + publish_mock.assert_awaited_once_with( topic_prefix="topic-prefix", topic_levels=( "cover", @@ -93,13 +97,14 @@ def test__report_position( "position", ), payload=expected_payload, - mqtt_client="dummy", + mqtt_client=mqtt_client, ) assert not caplog.record_tuples +@pytest.mark.asyncio @pytest.mark.parametrize("position", ("", 'lambda: print("")')) -def test__report_position_invalid( +async def test__report_position_invalid( caplog: _pytest.logging.LogCaptureFixture, position: str ) -> None: with unittest.mock.patch( @@ -115,15 +120,18 @@ def test__report_position_invalid( ), pytest.raises( ValueError ): - actor._report_position(mqtt_client="dummy", mqtt_topic_prefix="dummy2") + await actor._report_position( + mqtt_client=unittest.mock.Mock(), mqtt_topic_prefix="dummy2" + ) publish_mock.assert_not_called() +@pytest.mark.asyncio @pytest.mark.parametrize("topic_prefix", ["", "homeassistant/"]) @pytest.mark.parametrize(("battery_percent", "battery_percent_encoded"), [(42, b"42")]) @pytest.mark.parametrize("report_position", [True, False]) @pytest.mark.parametrize(("position", "position_encoded"), [(21, b"21")]) -def test__update_and_report_device_info( +async def test__update_and_report_device_info( topic_prefix: str, report_position: bool, battery_percent: int, @@ -136,22 +144,22 @@ def test__update_and_report_device_info( actor._get_device()._switchbot_device_data = { "data": {"battery": battery_percent, "position": position} } - mqtt_client_mock = unittest.mock.MagicMock() + mqtt_client_mock = unittest.mock.AsyncMock() with unittest.mock.patch("switchbot.SwitchbotCurtain.update") as update_mock: - actor._update_and_report_device_info( + await actor._update_and_report_device_info( mqtt_client=mqtt_client_mock, mqtt_topic_prefix=topic_prefix, report_position=report_position, ) update_mock.assert_called_once_with() - assert mqtt_client_mock.publish.call_count == (1 + report_position) + assert mqtt_client_mock.publish.await_count == (1 + report_position) assert ( unittest.mock.call( topic=topic_prefix + "cover/switchbot-curtain/dummy/battery-percentage", payload=battery_percent_encoded, retain=True, ) - in mqtt_client_mock.publish.call_args_list + in mqtt_client_mock.publish.await_args_list ) if report_position: assert ( @@ -160,10 +168,11 @@ def test__update_and_report_device_info( payload=position_encoded, retain=True, ) - in mqtt_client_mock.publish.call_args_list + in mqtt_client_mock.publish.await_args_list ) +@pytest.mark.asyncio @pytest.mark.parametrize( "exception", [ @@ -171,18 +180,21 @@ def test__update_and_report_device_info( bluepy.btle.BTLEManagementError("test"), ], ) -def test__update_and_report_device_info_update_error(exception: Exception) -> None: +async def test__update_and_report_device_info_update_error( + exception: Exception, +) -> None: actor = _CurtainMotor(mac_address="dummy", retry_count=21, password=None) mqtt_client_mock = unittest.mock.MagicMock() with unittest.mock.patch.object( actor._get_device(), "update", side_effect=exception ), pytest.raises(type(exception)): - actor._update_and_report_device_info( + await actor._update_and_report_device_info( mqtt_client_mock, mqtt_topic_prefix="dummy", report_position=True ) mqtt_client_mock.publish.assert_not_called() +@pytest.mark.asyncio @pytest.mark.parametrize("topic_prefix", ["topic-prfx"]) @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff", "aa:bb:cc:11:22:33"]) @pytest.mark.parametrize("password", ["pa$$word", None]) @@ -203,7 +215,7 @@ def test__update_and_report_device_info_update_error(exception: Exception) -> No ) @pytest.mark.parametrize("update_device_info", [True, False]) @pytest.mark.parametrize("command_successful", [True, False]) -def test_execute_command( +async def test_execute_command( caplog: _pytest.logging.LogCaptureFixture, topic_prefix: str, mac_address: str, @@ -214,12 +226,14 @@ def test_execute_command( update_device_info: bool, command_successful: bool, ) -> None: + # pylint: disable=too-many-locals with unittest.mock.patch( "switchbot.SwitchbotCurtain.__init__", return_value=None ) as device_init_mock, caplog.at_level(logging.INFO): actor = _CurtainMotor( mac_address=mac_address, retry_count=retry_count, password=password ) + mqtt_client = unittest.mock.Mock() with unittest.mock.patch.object( actor, "report_state" ) as report_mock, unittest.mock.patch( @@ -227,8 +241,8 @@ def test_execute_command( ) as action_mock, unittest.mock.patch.object( actor, "_update_and_report_device_info" ) as update_device_info_mock: - actor.execute_command( - mqtt_client="dummy", + await actor.execute_command( + mqtt_client=mqtt_client, mqtt_message_payload=message_payload, update_device_info=update_device_info, mqtt_topic_prefix=topic_prefix, @@ -248,8 +262,8 @@ def test_execute_command( f"switchbot curtain {mac_address} {state_str}", ) ] - report_mock.assert_called_once_with( - mqtt_client="dummy", + report_mock.assert_awaited_once_with( + mqtt_client=mqtt_client, mqtt_topic_prefix=topic_prefix, # https://www.home-assistant.io/integrations/cover.mqtt/#state_opening state={b"open": b"opening", b"close": b"closing", b"stop": b""}[ @@ -266,8 +280,8 @@ def test_execute_command( ] report_mock.assert_not_called() if update_device_info and command_successful: - update_device_info_mock.assert_called_once_with( - mqtt_client="dummy", + update_device_info_mock.assert_awaited_once_with( + mqtt_client=mqtt_client, report_position=(action_name == "switchbot.SwitchbotCurtain.stop"), mqtt_topic_prefix=topic_prefix, ) @@ -275,10 +289,11 @@ def test_execute_command( update_device_info_mock.assert_not_called() +@pytest.mark.asyncio @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"]) @pytest.mark.parametrize("password", ["secret"]) @pytest.mark.parametrize("message_payload", [b"OEFFNEN", b""]) -def test_execute_command_invalid_payload( +async def test_execute_command_invalid_payload( caplog: _pytest.logging.LogCaptureFixture, mac_address: str, password: str, @@ -289,8 +304,8 @@ def test_execute_command_invalid_payload( ) as device_mock, caplog.at_level(logging.INFO): actor = _CurtainMotor(mac_address=mac_address, retry_count=7, password=password) with unittest.mock.patch.object(actor, "report_state") as report_mock: - actor.execute_command( - mqtt_client="dummy", + await actor.execute_command( + mqtt_client=unittest.mock.Mock(), mqtt_message_payload=message_payload, update_device_info=True, mqtt_topic_prefix="dummy", @@ -309,9 +324,10 @@ def test_execute_command_invalid_payload( ] +@pytest.mark.asyncio @pytest.mark.parametrize("mac_address", ["aa:bb:cc:dd:ee:ff"]) @pytest.mark.parametrize("message_payload", [b"OPEN", b"CLOSE", b"STOP"]) -def test_execute_command_bluetooth_error( +async def test_execute_command_bluetooth_error( caplog: _pytest.logging.LogCaptureFixture, mac_address: str, message_payload: bytes ) -> None: """ @@ -326,10 +342,10 @@ def test_execute_command_bluetooth_error( f"Failed to connect to peripheral {mac_address}, addr type: random" ), ), caplog.at_level(logging.ERROR): - _CurtainMotor( + await _CurtainMotor( mac_address=mac_address, retry_count=0, password="secret" ).execute_command( - mqtt_client="dummy", + mqtt_client=unittest.mock.Mock(), mqtt_message_payload=message_payload, update_device_info=True, mqtt_topic_prefix="dummy", diff --git a/tests/test_switchbot_curtain_motor_position.py b/tests/test_switchbot_curtain_motor_position.py index e7ee6e3..0778bde 100644 --- a/tests/test_switchbot_curtain_motor_position.py +++ b/tests/test_switchbot_curtain_motor_position.py @@ -19,34 +19,34 @@ import logging import unittest.mock +import aiomqtt import _pytest.logging # pylint: disable=import-private-name; typing import pytest -from paho.mqtt.client import MQTTMessage # pylint: disable=import-private-name; internal from switchbot_mqtt._actors import _CurtainMotor -from switchbot_mqtt._actors.base import _MQTTCallbackUserdata # pylint: disable=protected-access +@pytest.mark.asyncio @pytest.mark.parametrize( ("topic", "payload", "expected_mac_address", "expected_position_percent"), [ ( - b"home/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent", + "home/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent", b"42", "aa:bb:cc:dd:ee:ff", 42, ), ( - b"home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent", + "home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent", b"0", "11:22:33:44:55:66", 0, ), ( - b"home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent", + "home/cover/switchbot-curtain/11:22:33:44:55:66/position/set-percent", b"100", "11:22:33:44:55:66", 100, @@ -54,27 +54,27 @@ ], ) @pytest.mark.parametrize("retry_count", (3, 42)) -def test__mqtt_set_position_callback( +async def test__mqtt_set_position_callback( caplog: _pytest.logging.LogCaptureFixture, - topic: bytes, + topic: str, payload: bytes, expected_mac_address: str, retry_count: int, expected_position_percent: int, ) -> None: - callback_userdata = _MQTTCallbackUserdata( - retry_count=retry_count, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="home/", + message = aiomqtt.Message( + topic=topic, payload=payload, qos=0, retain=False, mid=0, properties=None ) - message = MQTTMessage(topic=topic) - message.payload = payload with unittest.mock.patch( "switchbot.SwitchbotCurtain" ) as device_init_mock, caplog.at_level(logging.DEBUG): - _CurtainMotor._mqtt_set_position_callback( - mqtt_client="client dummy", userdata=callback_userdata, message=message + await _CurtainMotor._mqtt_set_position_callback( + mqtt_client=unittest.mock.Mock(), + message=message, + retry_count=retry_count, + device_passwords={}, + fetch_device_info=False, + mqtt_topic_prefix="home/", ) device_init_mock.assert_called_once_with( mac=expected_mac_address, @@ -99,25 +99,28 @@ def test__mqtt_set_position_callback( ] -def test__mqtt_set_position_callback_ignore_retained( +@pytest.mark.asyncio +async def test__mqtt_set_position_callback_ignore_retained( caplog: _pytest.logging.LogCaptureFixture, ) -> None: - callback_userdata = _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="whatever", - ) - message = MQTTMessage( - topic=b"homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent" + message = aiomqtt.Message( + topic="homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent", + payload=b"42", + qos=0, + retain=True, + mid=0, + properties=None, ) - message.payload = b"42" - message.retain = True with unittest.mock.patch( "switchbot.SwitchbotCurtain" ) as device_init_mock, caplog.at_level(logging.INFO): - _CurtainMotor._mqtt_set_position_callback( - mqtt_client="client dummy", userdata=callback_userdata, message=message + await _CurtainMotor._mqtt_set_position_callback( + mqtt_client=unittest.mock.Mock(), + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=False, + mqtt_topic_prefix="whatever", ) device_init_mock.assert_not_called() assert caplog.record_tuples == [ @@ -130,22 +133,28 @@ def test__mqtt_set_position_callback_ignore_retained( ] -def test__mqtt_set_position_callback_unexpected_topic( +@pytest.mark.asyncio +async def test__mqtt_set_position_callback_unexpected_topic( caplog: _pytest.logging.LogCaptureFixture, ) -> None: - callback_userdata = _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="", + message = aiomqtt.Message( + topic="switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set", + payload=b"42", + qos=0, + retain=False, + mid=0, + properties=None, ) - message = MQTTMessage(topic=b"switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set") - message.payload = b"42" with unittest.mock.patch( "switchbot.SwitchbotCurtain" ) as device_init_mock, caplog.at_level(logging.INFO): - _CurtainMotor._mqtt_set_position_callback( - mqtt_client="client dummy", userdata=callback_userdata, message=message + await _CurtainMotor._mqtt_set_position_callback( + mqtt_client=unittest.mock.Mock(), + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=False, + mqtt_topic_prefix="", ) device_init_mock.assert_not_called() assert caplog.record_tuples == [ @@ -157,24 +166,28 @@ def test__mqtt_set_position_callback_unexpected_topic( ] -def test__mqtt_set_position_callback_invalid_mac_address( +@pytest.mark.asyncio +async def test__mqtt_set_position_callback_invalid_mac_address( caplog: _pytest.logging.LogCaptureFixture, ) -> None: - callback_userdata = _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="tnatsissaemoh/", - ) - message = MQTTMessage( - topic=b"tnatsissaemoh/cover/switchbot-curtain/aa:bb:cc:dd:ee/position/set-percent" + message = aiomqtt.Message( + topic="tnatsissaemoh/cover/switchbot-curtain/aa:bb:cc:dd:ee/position/set-percent", + payload=b"42", + qos=0, + retain=False, + mid=0, + properties=None, ) - message.payload = b"42" with unittest.mock.patch( "switchbot.SwitchbotCurtain" ) as device_init_mock, caplog.at_level(logging.INFO): - _CurtainMotor._mqtt_set_position_callback( - mqtt_client="client dummy", userdata=callback_userdata, message=message + await _CurtainMotor._mqtt_set_position_callback( + mqtt_client=unittest.mock.Mock(), + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=False, + mqtt_topic_prefix="tnatsissaemoh/", ) device_init_mock.assert_not_called() assert caplog.record_tuples == [ @@ -186,26 +199,30 @@ def test__mqtt_set_position_callback_invalid_mac_address( ] +@pytest.mark.asyncio @pytest.mark.parametrize("payload", [b"-1", b"123"]) -def test__mqtt_set_position_callback_invalid_position( +async def test__mqtt_set_position_callback_invalid_position( caplog: _pytest.logging.LogCaptureFixture, payload: bytes, ) -> None: - callback_userdata = _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="homeassistant/", + message = aiomqtt.Message( + topic="homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent", + payload=payload, + qos=0, + retain=False, + mid=0, + properties=None, ) - message = MQTTMessage( - topic=b"homeassistant/cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent" - ) - message.payload = payload with unittest.mock.patch( "switchbot.SwitchbotCurtain" ) as device_init_mock, caplog.at_level(logging.INFO): - _CurtainMotor._mqtt_set_position_callback( - mqtt_client="client dummy", userdata=callback_userdata, message=message + await _CurtainMotor._mqtt_set_position_callback( + mqtt_client=unittest.mock.Mock(), + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=False, + mqtt_topic_prefix="homeassistant/", ) device_init_mock.assert_called_once() device_init_mock().set_position.assert_not_called() @@ -218,26 +235,30 @@ def test__mqtt_set_position_callback_invalid_position( ] -def test__mqtt_set_position_callback_command_failed( +@pytest.mark.asyncio +async def test__mqtt_set_position_callback_command_failed( caplog: _pytest.logging.LogCaptureFixture, ) -> None: - callback_userdata = _MQTTCallbackUserdata( - retry_count=3, - device_passwords={}, - fetch_device_info=False, - mqtt_topic_prefix="", - ) - message = MQTTMessage( - topic=b"cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent" + message = aiomqtt.Message( + topic="cover/switchbot-curtain/aa:bb:cc:dd:ee:ff/position/set-percent", + payload=b"21", + qos=0, + retain=False, + mid=0, + properties=None, ) - message.payload = b"21" with unittest.mock.patch( "switchbot.SwitchbotCurtain" ) as device_init_mock, caplog.at_level(logging.INFO): device_init_mock().set_position.return_value = False device_init_mock.reset_mock() - _CurtainMotor._mqtt_set_position_callback( - mqtt_client="client dummy", userdata=callback_userdata, message=message + await _CurtainMotor._mqtt_set_position_callback( + mqtt_client=unittest.mock.Mock(), + message=message, + retry_count=3, + device_passwords={}, + fetch_device_info=False, + mqtt_topic_prefix="", ) device_init_mock.assert_called_once() device_init_mock().set_position.assert_called_with(21)