diff --git a/aiomqtt/client.py b/aiomqtt/client.py index 36bd000..fd4f90f 100644 --- a/aiomqtt/client.py +++ b/aiomqtt/client.py @@ -1036,7 +1036,11 @@ async def __aenter__(self) -> Client: msg = "Does not support reentrant" raise MqttReentrantError(msg) await self._lock.acquire() - await self.connect() + try: + await self.connect() + except Exception: + self._lock.release() + raise return self async def __aexit__( diff --git a/tests/test_client.py b/tests/test_client.py index 5c63cff..f588943 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,7 +13,7 @@ from anyio.abc import TaskStatus from aiomqtt import Client, ProtocolVersion, TLSParameters, Topic, Wildcard, Will -from aiomqtt.error import MqttReentrantError +from aiomqtt.error import MqttError, MqttReentrantError from aiomqtt.types import PayloadType pytestmark = pytest.mark.anyio @@ -536,3 +536,12 @@ async def test_client_connecting_disconnected_done() -> None: client._disconnected.set_result(None) await client.connect() await client.disconnect() + + +@pytest.mark.network +async def test_client_aenter_connect_error_lock_release() -> None: + client = Client(hostname="aenter_connect_error_lock_release") + with pytest.raises(MqttError): # noqa: PT012 + async with client: + ... + assert not client._lock.locked()