From 70910bc6a63e607332b4f12754ba470651eb878c Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 27 Oct 2024 15:56:18 -0400 Subject: [PATCH] Cleanly shut down the serial port on disconnect (#259) * Cleanly shut down the serial port on disconnect * Send `connection_lost` even if we do not have an open serial connection * Call `super().close()` in `SerialProtocol` * Use `self._transport.write` instead of `send_data` * Let zigpy handle flow control * Bump minimum zigpy version * Fix unit tests * Make `api` an async fixture to grab reference to loop early * Set default pytest-asyncio fixture loop scope * Fix unit test failing due to event loop caching issue in pytest-asyncio * Bring test coverage up --- pyproject.toml | 3 +- tests/test_api.py | 44 ++++++++++++++++++------ tests/test_application.py | 7 ++-- tests/test_uart.py | 13 +++++-- zigpy_deconz/api.py | 27 +++++++-------- zigpy_deconz/uart.py | 55 ++++++++++++------------------ zigpy_deconz/zigbee/application.py | 4 +-- 7 files changed, 87 insertions(+), 66 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 39a8c86..cc25e77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"} requires-python = ">=3.8" dependencies = [ "voluptuous", - "zigpy>=0.68.0", + "zigpy>=0.70.0", 'async-timeout; python_version<"3.11"', ] @@ -47,6 +47,7 @@ ignore_errors = true [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" [tool.flake8] exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"] diff --git a/tests/test_api.py b/tests/test_api.py index c7e2fcc..b3fcd2f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -25,15 +25,23 @@ @pytest.fixture -def gateway(): +async def gateway(): return uart.Gateway(api=None) @pytest.fixture -def api(gateway, mock_command_rsp): +async def api(gateway, mock_command_rsp): + loop = asyncio.get_running_loop() + async def mock_connect(config, api): + transport = MagicMock() + transport.close = MagicMock( + side_effect=lambda: loop.call_soon(gateway.connection_lost, None) + ) + gateway._api = api - gateway.connection_made(MagicMock()) + gateway.connection_made(transport) + return gateway with patch("zigpy_deconz.uart.connect", side_effect=mock_connect): @@ -178,15 +186,33 @@ async def test_connect(api, mock_command_rsp): await api.connect() +async def test_connect_failure(api, mock_command_rsp): + transport = None + + def mock_version(*args, **kwargs): + nonlocal transport + transport = api._uart._transport + + raise asyncio.TimeoutError() + + with patch.object(api, "version", side_effect=mock_version): + # We connect but fail to probe + with pytest.raises(asyncio.TimeoutError): + await api.connect() + + assert api._uart is None + assert len(transport.close.mock_calls) == 1 + + async def test_close(api): await api.connect() uart = api._uart - uart.close = MagicMock(wraps=uart.close) + uart.disconnect = AsyncMock() - api.close() + await api.disconnect() assert api._uart is None - assert uart.close.call_count == 1 + assert uart.disconnect.call_count == 1 def test_commands(): @@ -898,11 +924,9 @@ async def test_data_poller(api, mock_command_rsp): # The task is cancelled on close task = api._data_poller_task - api.close() + await api.disconnect() assert api._data_poller_task is None - - if sys.version_info >= (3, 11): - assert task.cancelling() + assert task.done() async def test_get_device_state(api, mock_command_rsp): diff --git a/tests/test_application.py b/tests/test_application.py index 59c480a..f51794d 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -187,6 +187,7 @@ async def test_connect_failure(app): with patch.object(application, "Deconz") as api_mock: api = api_mock.return_value = MagicMock() api.connect = AsyncMock(side_effect=RuntimeError("Broken")) + api.disconnect = AsyncMock() app._api = None @@ -195,16 +196,16 @@ async def test_connect_failure(app): assert app._api is None api.connect.assert_called_once() - api.close.assert_called_once() + api.disconnect.assert_called_once() async def test_disconnect(app): - api_close = app._api.close = MagicMock() + api_disconnect = app._api.disconnect = AsyncMock() await app.disconnect() assert app._api is None - assert api_close.call_count == 1 + assert api_disconnect.call_count == 1 async def test_disconnect_no_api(app): diff --git a/tests/test_uart.py b/tests/test_uart.py index 432c48f..17fc320 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -4,7 +4,11 @@ from unittest import mock import pytest -from zigpy.config import CONF_DEVICE_BAUDRATE, CONF_DEVICE_PATH +from zigpy.config import ( + CONF_DEVICE_BAUDRATE, + CONF_DEVICE_FLOW_CONTROL, + CONF_DEVICE_PATH, +) import zigpy.serial from zigpy_deconz import uart @@ -28,7 +32,12 @@ async def mock_conn(loop, protocol_factory, **kwargs): monkeypatch.setattr(zigpy.serial, "create_serial_connection", mock_conn) await uart.connect( - {CONF_DEVICE_PATH: "/dev/null", CONF_DEVICE_BAUDRATE: 115200}, api + { + CONF_DEVICE_PATH: "/dev/null", + CONF_DEVICE_BAUDRATE: 115200, + CONF_DEVICE_FLOW_CONTROL: None, + }, + api, ) diff --git a/zigpy_deconz/api.py b/zigpy_deconz/api.py index e91991a..2a2d77d 100644 --- a/zigpy_deconz/api.py +++ b/zigpy_deconz/api.py @@ -14,7 +14,6 @@ else: from asyncio import timeout as asyncio_timeout # pragma: no cover -from zigpy.config import CONF_DEVICE_PATH from zigpy.datastructures import PriorityLock from zigpy.types import ( APSStatus, @@ -461,37 +460,37 @@ def protocol_version(self) -> int: async def connect(self) -> None: assert self._uart is None + self._uart = await zigpy_deconz.uart.connect(self._config, self) - await self.version() + try: + await self.version() + device_state_rsp = await self.send_command(CommandId.device_state) + except Exception: + await self.disconnect() + self._uart = None + raise - device_state_rsp = await self.send_command(CommandId.device_state) self._device_state = device_state_rsp["device_state"] self._data_poller_task = asyncio.create_task(self._data_poller()) - def connection_lost(self, exc: Exception) -> None: + def connection_lost(self, exc: Exception | None) -> None: """Lost serial connection.""" - LOGGER.debug( - "Serial %r connection lost unexpectedly: %r", - self._config[CONF_DEVICE_PATH], - exc, - ) - if self._app is not None: self._app.connection_lost(exc) - def close(self): - self._app = None - + async def disconnect(self): if self._data_poller_task is not None: self._data_poller_task.cancel() self._data_poller_task = None if self._uart is not None: - self._uart.close() + await self._uart.disconnect() self._uart = None + self._app = None + def _get_command_priority(self, command: Command) -> int: return { # The watchdog is fed using `write_parameter` and `get_device_state` so they diff --git a/zigpy_deconz/uart.py b/zigpy_deconz/uart.py index f555787..3481fb6 100644 --- a/zigpy_deconz/uart.py +++ b/zigpy_deconz/uart.py @@ -1,9 +1,11 @@ """Uart module.""" +from __future__ import annotations + import asyncio import binascii import logging -from typing import Callable, Dict +from typing import Any, Callable import zigpy.config import zigpy.serial @@ -11,49 +13,38 @@ LOGGER = logging.getLogger(__name__) -class Gateway(asyncio.Protocol): +class Gateway(zigpy.serial.SerialProtocol): END = b"\xC0" ESC = b"\xDB" ESC_END = b"\xDC" ESC_ESC = b"\xDD" - def __init__(self, api, connected_future=None): + def __init__(self, api): """Initialize instance of the UART gateway.""" - + super().__init__() self._api = api - self._buffer = b"" - self._connected_future = connected_future - self._transport = None - def connection_lost(self, exc) -> None: + def connection_lost(self, exc: Exception | None) -> None: """Port was closed expectedly or unexpectedly.""" + super().connection_lost(exc) - if exc is not None: - LOGGER.warning("Lost connection: %r", exc, exc_info=exc) - - self._api.connection_lost(exc) - - def connection_made(self, transport): - """Call this when the uart connection is established.""" - - LOGGER.debug("Connection made") - self._transport = transport - if self._connected_future and not self._connected_future.done(): - self._connected_future.set_result(True) + if self._api is not None: + self._api.connection_lost(exc) def close(self): - self._transport.close() + super().close() + self._api = None - def send(self, data): + def send(self, data: bytes) -> None: """Send data, taking care of escaping and framing.""" - LOGGER.debug("Send: %s", binascii.hexlify(data).decode()) checksum = bytes(self._checksum(data)) frame = self._escape(data + checksum) self._transport.write(self.END + frame + self.END) - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Handle data received from the uart.""" - self._buffer += data + super().data_received(data) + while self._buffer: end = self._buffer.find(self.END) if end < 0: @@ -121,23 +112,19 @@ def _checksum(self, data): return bytes(ret) -async def connect(config: Dict[str, any], api: Callable) -> Gateway: - loop = asyncio.get_running_loop() - connected_future = loop.create_future() - protocol = Gateway(api, connected_future) +async def connect(config: dict[str, Any], api: Callable) -> Gateway: + protocol = Gateway(api) LOGGER.debug("Connecting to %s", config[zigpy.config.CONF_DEVICE_PATH]) _, protocol = await zigpy.serial.create_serial_connection( - loop=loop, + loop=asyncio.get_running_loop(), protocol_factory=lambda: protocol, url=config[zigpy.config.CONF_DEVICE_PATH], baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE], - xonxoff=False, + flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL], ) - await connected_future - - LOGGER.debug("Connected to %s", config[zigpy.config.CONF_DEVICE_PATH]) + await protocol.wait_until_connected() return protocol diff --git a/zigpy_deconz/zigbee/application.py b/zigpy_deconz/zigbee/application.py index 8ad9fa6..afc796e 100644 --- a/zigpy_deconz/zigbee/application.py +++ b/zigpy_deconz/zigbee/application.py @@ -97,7 +97,7 @@ async def connect(self): try: await api.connect() except Exception: - api.close() + await api.disconnect() raise self._api = api @@ -109,7 +109,7 @@ async def disconnect(self): self._delayed_neighbor_scan_task = None if self._api is not None: - self._api.close() + await self._api.disconnect() self._api = None async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60):