Skip to content

Commit

Permalink
refactor: Remove deprecated un/filtered_messages generators
Browse files Browse the repository at this point in the history
  • Loading branch information
empicano committed Jan 3, 2024
1 parent b668392 commit bce8236
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 175 deletions.
101 changes: 3 additions & 98 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class Will:
properties: mqtt.Properties | None = None


# TLS set parameter class
@dataclass(frozen=True)
class TLSParameters:
ca_certs: str | None = None
Expand All @@ -81,7 +80,6 @@ class TLSParameters:
keyfile_password: str | None = None


# Proxy parameters class
class ProxySettings:
def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -132,7 +130,8 @@ class Wildcard:
"""MQTT wildcard that can be subscribed to, but not published to.
A wildcard is similar to a topic, but can optionally contain ``+`` and ``#``
placeholders.
placeholders. You can access the ``value`` attribute directly to perform ``str``
operations on a wildcard.
Args:
value: The wildcard string.
Expand Down Expand Up @@ -392,6 +391,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
[mqtt.Client, Any, mqtt.MQTTMessage], None
] | None = None

# TODO(felix): This does not seem to be used anywhere. Remove?
self._outgoing_calls_sem: asyncio.Semaphore | None
if max_concurrent_outgoing_calls is not None:
self._outgoing_calls_sem = asyncio.Semaphore(max_concurrent_outgoing_calls)
Expand Down Expand Up @@ -605,50 +605,6 @@ async def publish( # noqa: PLR0913
# Wait for confirmation
await self._wait_for(confirmation.wait(), timeout=timeout)

@asynccontextmanager
async def filtered_messages(
self, topic_filter: str, *, queue_maxsize: int = 0
) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]:
"""Return async generator of messages that match the given filter."""
self._logger.warning(
"filtered_messages() is deprecated and will be removed in a future version."
" Use messages() together with Topic.matches() instead."
)
callback, generator = self._deprecated_callback_and_generator(
log_context=f'topic_filter="{topic_filter}"', queue_maxsize=queue_maxsize
)
try:
self._client.message_callback_add(topic_filter, callback)
# Back to the caller (run whatever is inside the with statement)
yield generator
finally:
# We are exiting the with statement. Remove the topic filter.
self._client.message_callback_remove(topic_filter)

@asynccontextmanager
async def unfiltered_messages(
self, *, queue_maxsize: int = 0
) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]:
"""Return async generator of all messages that are not caught in filters."""
self._logger.warning(
"unfiltered_messages() is deprecated and will be removed in a future"
" version. Use messages() instead."
)
# Early out
if self._unfiltered_messages_callback is not None:
msg = "Only a single unfiltered_messages generator can be used at a time"
raise RuntimeError(msg)
callback, generator = self._deprecated_callback_and_generator(
log_context="unfiltered", queue_maxsize=queue_maxsize
)
try:
self._unfiltered_messages_callback = callback
# Back to the caller (run whatever is inside the with statement)
yield generator
finally:
# We are exiting the with statement. Unset the callback.
self._unfiltered_messages_callback = None

@asynccontextmanager
async def messages(
self,
Expand Down Expand Up @@ -682,57 +638,6 @@ async def messages(
# We are exiting the with statement. Remove the callback from the list.
self._on_message_callbacks.remove(callback)

def _deprecated_callback_and_generator(
self, *, log_context: str, queue_maxsize: int = 0
) -> tuple[
Callable[[mqtt.Client, Any, mqtt.MQTTMessage], None],
AsyncGenerator[mqtt.MQTTMessage, None],
]:
# Queue to hold the incoming messages
messages: asyncio.Queue[mqtt.MQTTMessage] = asyncio.Queue(maxsize=queue_maxsize)

# Callback for the underlying API
def _put_in_queue(
client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage
) -> None:
try:
messages.put_nowait(message)
except asyncio.QueueFull:
self._logger.warning(
"[%s] Message queue is full. Discarding message.", log_context
)

# The generator that we give to the caller
async def _message_generator() -> AsyncGenerator[mqtt.MQTTMessage, None]:
# Forward all messages from the queue
while True:
# Wait until we either:
# 1. Receive a message
# 2. Disconnect from the broker
get: asyncio.Task[mqtt.MQTTMessage] = self._loop.create_task(
messages.get()
)
try:
done, _ = await asyncio.wait(
(get, self._disconnected), return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
# If the asyncio.wait is cancelled, we must make sure
# to also cancel the underlying tasks.
get.cancel()
raise
if get in done:
# We received a message. Return the result.
yield get.result()
else:
# We got disconnected from the broker. Cancel the "get" task.
get.cancel()
# Stop the generator with the following exception
msg = "Disconnected during message iteration"
raise MqttError(msg)

return _put_in_queue, _message_generator()

def _callback_and_generator(
self,
*,
Expand Down
109 changes: 32 additions & 77 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,94 +91,49 @@ async def test_topic_matches() -> None:
@pytest.mark.network
async def test_multiple_messages_generators() -> None:
"""Test that multiple Client.messages() generators can be used at the same time."""
topic = TOPIC_PREFIX + "multiple_messages_generators"
topic = TOPIC_PREFIX + "test_multiple_messages_generators"

async def handler(tg: anyio.abc.TaskGroup) -> None:
async def handle(tg: anyio.abc.TaskGroup) -> None:
async with client.messages() as messages:
async for message in messages:
assert str(message.topic) == topic
assert message.topic.value == topic
tg.cancel_scope.cancel()

async with Client(HOSTNAME) as client, anyio.create_task_group() as tg:
await client.subscribe(topic)
tg.start_soon(handler, tg)
tg.start_soon(handler, tg)
tg.start_soon(handle, tg)
tg.start_soon(handle, tg)
await anyio.wait_all_tasks_blocked()
await client.publish(topic)


@pytest.mark.network
async def test_client_filtered_messages() -> None:
topic_header = TOPIC_PREFIX + "filtered_messages/"
good_topic = topic_header + "good"
bad_topic = topic_header + "bad"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(good_topic) as messages:
async for message in messages:
assert message.topic == good_topic
tg.cancel_scope.cancel()

async with Client(HOSTNAME) as client, anyio.create_task_group() as tg:
await client.subscribe(topic_header + "#")
tg.start_soon(handle_messages, tg)
await anyio.wait_all_tasks_blocked()
await client.publish(bad_topic, 2)
await client.publish(good_topic, 2)


@pytest.mark.network
async def test_client_unfiltered_messages() -> None:
topic_header = TOPIC_PREFIX + "unfiltered_messages/"
topic_filtered = topic_header + "filtered"
topic_unfiltered = topic_header + "unfiltered"

async def handle_unfiltered_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.unfiltered_messages() as messages:
async for message in messages:
assert message.topic == topic_unfiltered
tg.cancel_scope.cancel()

async def handle_filtered_messages() -> None:
async with client.filtered_messages(topic_filtered) as messages:
async for message in messages:
assert message.topic == topic_filtered

async with Client(HOSTNAME) as client, anyio.create_task_group() as tg:
await client.subscribe(topic_header + "#")
tg.start_soon(handle_filtered_messages)
tg.start_soon(handle_unfiltered_messages, tg)
await anyio.wait_all_tasks_blocked()
await client.publish(topic_filtered, 2)
await client.publish(topic_unfiltered, 2)


@pytest.mark.network
async def test_client_unsubscribe() -> None:
topic_header = TOPIC_PREFIX + "unsubscribe/"
topic1 = topic_header + "1"
topic2 = topic_header + "2"
"""Test that messages are no longer received after unsubscribing from a topic."""
topic_1 = TOPIC_PREFIX + "test_client_unsubscribe/1"
topic_2 = TOPIC_PREFIX + "test_client_unsubscribe/2"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.unfiltered_messages() as messages:
async def handle(tg: anyio.abc.TaskGroup) -> None:
async with client.messages() as messages:
is_first_message = True
async for message in messages:
if is_first_message:
assert message.topic == topic1
assert message.topic.value == topic_1
is_first_message = False
else:
assert message.topic == topic2
assert message.topic.value == topic_2
tg.cancel_scope.cancel()

async with Client(HOSTNAME) as client, anyio.create_task_group() as tg:
await client.subscribe(topic1)
await client.subscribe(topic2)
tg.start_soon(handle_messages, tg)
await client.subscribe(topic_1)
await client.subscribe(topic_2)
tg.start_soon(handle, tg)
await anyio.wait_all_tasks_blocked()
await client.publish(topic1, 2)
await client.unsubscribe(topic1)
await client.publish(topic1, 2)
await client.publish(topic2, 2)
await client.publish(topic_1, None)
await client.unsubscribe(topic_1)
await client.publish(topic_1, None)
# Test that other subscriptions still receive messages
await client.publish(topic_2, None)


@pytest.mark.parametrize(
Expand All @@ -192,17 +147,17 @@ async def test_client_id(protocol: ProtocolVersion, length: int) -> None:

@pytest.mark.network
async def test_client_will() -> None:
topic = TOPIC_PREFIX + "will"
topic = TOPIC_PREFIX + "test_client_will"
event = anyio.Event()

async def launch_client() -> None:
with anyio.CancelScope(shield=True) as cs:
async with Client(HOSTNAME) as client:
await client.subscribe(topic)
event.set()
async with client.filtered_messages(topic) as messages:
async with client.messages() as messages:
async for message in messages:
assert message.topic == topic
assert message.topic.value == topic
cs.cancel()

async with anyio.create_task_group() as tg:
Expand All @@ -214,12 +169,12 @@ async def launch_client() -> None:

@pytest.mark.network
async def test_client_tls_context() -> None:
topic = TOPIC_PREFIX + "tls_context"
topic = TOPIC_PREFIX + "test_client_tls_context"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
async with client.messages() as messages:
async for message in messages:
assert message.topic == topic
assert message.topic.value == topic
tg.cancel_scope.cancel()

async with Client(
Expand All @@ -238,9 +193,9 @@ async def test_client_tls_params() -> None:
topic = TOPIC_PREFIX + "tls_params"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
async with client.messages() as messages:
async for message in messages:
assert message.topic == topic
assert message.topic.value == topic
tg.cancel_scope.cancel()

async with Client(
Expand All @@ -261,9 +216,9 @@ async def test_client_username_password() -> None:
topic = TOPIC_PREFIX + "username_password"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
async with client.messages() as messages:
async for message in messages:
assert message.topic == topic
assert message.topic.value == topic
tg.cancel_scope.cancel()

async with Client(
Expand Down Expand Up @@ -335,9 +290,9 @@ async def test_client_websockets() -> None:
topic = TOPIC_PREFIX + "websockets"

async def handle_messages(tg: anyio.abc.TaskGroup) -> None:
async with client.filtered_messages(topic) as messages:
async with client.messages() as messages:
async for message in messages:
assert message.topic == topic
assert message.topic.value == topic
tg.cancel_scope.cancel()

async with Client(
Expand Down

0 comments on commit bce8236

Please sign in to comment.