Skip to content

Commit

Permalink
Wrap resourse error in channels, add 2 tests for InMemoryChannelLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
dolamroth committed Feb 7, 2024
1 parent a258927 commit 1b9b59b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
9 changes: 8 additions & 1 deletion starlette_web/common/channels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MemoryObjectSendStream,
EndOfStream,
ClosedResourceError,
BrokenResourceError,
)

from starlette_web.common.channels.layers.base import BaseChannelLayer
Expand Down Expand Up @@ -50,6 +51,12 @@ async def disconnect(self) -> None:
await self._channel_layer.disconnect()

async def _listener(self) -> None:
async def _safe_send(_send_stream: MemoryObjectSendStream, _event: Event):
try:
await _send_stream.send(_event)
except (BrokenResourceError, ClosedResourceError):
pass

async with anyio.create_task_group() as task_group:
while True:
try:
Expand All @@ -61,7 +68,7 @@ async def _listener(self) -> None:
subscribers_list = list(self._subscribers.get(event.group, []))

for send_stream in subscribers_list:
task_group.start_soon(send_stream.send, event)
task_group.start_soon(_safe_send, send_stream, event)

async with self._manager_lock:
for group in self._subscribers.keys():
Expand Down
72 changes: 72 additions & 0 deletions starlette_web/tests/contrib/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,75 @@ async def subscriber_task(channel: Channel):

res = await_(task_coroutine())
assert res == [0, 0]

def test_multiple_same_subscribers_inmemorychannellayer(self):
async def task_coroutine():
_result = []

async def publisher_task(channel: Channel):
await anyio.sleep(0.1)
for _ in range(10):
await channel.publish("test_group", "Message")

async def subscriber_task(channel: Channel, _res: list, break_after: int):
async with channel.subscribe("test_group") as subscriber:
_message_counter = 0
async for message in subscriber:
if _message_counter >= break_after:
break
_res.append(message)
_message_counter += 1

async with Channel(InMemoryChannelLayer()) as channels:
async with anyio.create_task_group() as task_group:
task_group.cancel_scope.deadline = anyio.current_time() + 5.0
task_group.start_soon(publisher_task, channels)
task_group.start_soon(subscriber_task, channels, _result, 5)
task_group.start_soon(subscriber_task, channels, _result, 2)
task_group.start_soon(subscriber_task, channels, _result, 1)

return _result

res = await_(task_coroutine())
assert len(res) == 8

def test_publish_in_subscribe_inmemorychannellayer(self):
async def task_coroutine():
_result = []

async def pipeline_1(channel: Channel):
await anyio.sleep(0.1)
for _ in range(10):
await channel.publish("topic_1", "Message")

async def pipeline_2(channel: Channel):
async with channel.subscribe("topic_1") as subscriber:
_messages_count = 0
async for event in subscriber:
await channel.publish("topic_2", event.message)
_messages_count += 1
if _messages_count >= 10:
break

async def pipeline_3(channel: Channel, _res):
async with channel.subscribe("topic_2") as subscriber:
_messages_count = 0
async for event in subscriber:
_res.append(event)
# This topic is not listened by anyone
await channel.publish("topic_3", event.message)
_messages_count += 1
if _messages_count >= 10:
break

async with Channel(InMemoryChannelLayer()) as channels:
async with anyio.create_task_group() as task_group:
task_group.cancel_scope.deadline = anyio.current_time() + 3.0
task_group.start_soon(pipeline_1, channels)
task_group.start_soon(pipeline_2, channels)
task_group.start_soon(pipeline_3, channels, _result)

return _result

res = await_(task_coroutine())
assert len(res) == 10

0 comments on commit 1b9b59b

Please sign in to comment.