Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrites RedisChannelLayer.receive specific channel section #165

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 89 additions & 111 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,80 @@ class UnsupportedRedis(Exception):
pass


class ReceiveBuffer:
"""
Receive buffer

It manages waiters and buffers messages for all specific channels under the same 'real channel'
Also manages the receive loop for the 'real channel'
"""

def __init__(self, receive_single, real_channel):
self.loop = None
self.real_channel = real_channel
self.receive_single = receive_single
self.getters = collections.defaultdict(collections.deque)
self.buffers = collections.defaultdict(collections.deque)
self.receiver = None

def __bool__(self):
return bool(self.getters)

def get(self, channel):
"""
:param channel: name of the channel
:return: Future for the next message on channel
"""
assert channel.startswith(
self.real_channel
), "channel not managed by this buffer"
getter = self.loop.create_future()

if channel in self.buffers:
getter.set_result(self.buffers[channel].popleft())
if not self.buffers[channel]:
del self.buffers[channel]
else:
getter.channel = channel
getter.add_done_callback(self._getter_done_prematurely)
self.getters[channel].append(getter)

# ensure receiver is running
if not self.receiver:
self.receiver = asyncio.ensure_future(self.receiver_factory())
return getter

def _getter_done_prematurely(self, getter):
channel = getter.channel
self.getters[channel].remove(getter)
if not self.getters[channel]:
del self.getters[channel]
if not self and self.receiver:
self.receiver.cancel()

def put(self, channel, message):
if channel in self.getters:
getter = self.getters[channel].popleft()
getter.remove_done_callback(self._getter_done_prematurely)
if not self.getters[channel]:
del self.getters[channel]
getter.set_result(message)
else:
self.buffers[channel].append(message)

async def receiver_factory(self):
try:
while self:
message_channel, message = await self.receive_single(self.real_channel)
if type(message_channel) is list:
for chan in message_channel:
self.put(chan, message)
else:
self.put(message_channel, message)
finally:
self.receiver = None


class RedisChannelLayer(BaseChannelLayer):
"""
Redis channel layer.
Expand Down Expand Up @@ -209,14 +283,8 @@ def __init__(
)
# Set up any encryption objects
self._setup_encryption(symmetric_encryption_keys)
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
self.receive_lock = None
# Event loop they are trying to receive on
self.receive_event_loop = None
# Buffered messages by process-local channel name
self.receive_buffer = collections.defaultdict(asyncio.Queue)
self.receive_buffers = {}
# Detached channel cleanup tasks
self.receive_cleaners = []
# Per-channel cleanup locks to prevent a receive starting and moving
Expand Down Expand Up @@ -352,110 +420,20 @@ async def receive(self, channel):
), "Wrong client prefix"
# Enter receiving section
loop = asyncio.get_event_loop()
self.receive_count += 1
try:
if self.receive_count == 1:
# If we're the first coroutine in, create the receive lock!
self.receive_lock = asyncio.Lock()
self.receive_event_loop = loop
else:
# Otherwise, check our event loop matches
if self.receive_event_loop != loop:
raise RuntimeError(
"Two event loops are trying to receive() on one channel layer at once!"
)

# Wait for our message to appear
message = None
while self.receive_buffer[channel].empty():
tasks = [
self.receive_lock.acquire(),
self.receive_buffer[channel].get(),
]
tasks = [asyncio.ensure_future(task) for task in tasks]
try:
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
# Cancel all pending tasks.
task.cancel()
except asyncio.CancelledError:
# Ensure all tasks are cancelled if we are cancelled.
# Also see: https://bugs.python.org/issue23859
del self.receive_buffer[channel]
for task in tasks:
if not task.cancel():
assert task.done()
if task.result() is True:
self.receive_lock.release()

raise

message, token, exception = None, None, None
for task in done:
try:
result = task.result()
except Exception as error: # NOQA
# We should not propagate exceptions immediately as otherwise this may cause
# the lock to be held and never be released.
exception = error
continue

if result is True:
token = result
else:
assert isinstance(result, dict)
message = result

if message or exception:
if token:
# We will not be receving as we already have the message.
self.receive_lock.release()

if exception:
raise exception
else:
break
else:
assert token

# We hold the receive lock, receive and then release it.
try:
# There is no interruption point from when the message is
# unpacked in receive_single to when we get back here, so
# the following lines are essentially atomic.
message_channel, message = await self.receive_single(
real_channel
)
if type(message_channel) is list:
for chan in message_channel:
self.receive_buffer[chan].put_nowait(message)
else:
self.receive_buffer[message_channel].put_nowait(message)
message = None
except:
del self.receive_buffer[channel]
raise
finally:
self.receive_lock.release()

# We know there's a message available, because there
# couldn't have been any interruption between empty() and here
if message is None:
message = self.receive_buffer[channel].get_nowait()

if self.receive_buffer[channel].empty():
del self.receive_buffer[channel]
return message

finally:
self.receive_count -= 1
# If we were the last out, drop the receive lock
if self.receive_count == 0:
assert not self.receive_lock.locked()
self.receive_lock = None
self.receive_event_loop = None
if real_channel not in self.receive_buffers:
self.receive_buffers[real_channel] = ReceiveBuffer(
self.receive_single, real_channel
)
receive_buffer = self.receive_buffers[real_channel]

# Check our event loop matches
if receive_buffer.loop != loop and receive_buffer.receiver:
raise RuntimeError(
"Two event loops are trying to receive() on one channel layer at once!"
)
else:
receive_buffer.loop = loop
return await receive_buffer.get(channel)
else:
# Do a plain direct receive
return (await self.receive_single(channel))[1]
Expand Down
67 changes: 66 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from async_generator import async_generator, yield_

from asgiref.sync import async_to_sync
from channels_redis.core import ChannelFull, RedisChannelLayer
from channels_redis.core import ChannelFull, ReceiveBuffer, RedisChannelLayer

TEST_HOSTS = [("localhost", 6379)]

Expand Down Expand Up @@ -343,3 +343,68 @@ async def test_receive_cancel(channel_layer):
await asyncio.wait_for(task, None)
except asyncio.CancelledError:
pass


@pytest.mark.asyncio
async def test_receive_multiple_specific_prefixes(channel_layer):
"""
Makes sure we receive on multiple real channels
"""
channel_layer = RedisChannelLayer(capacity=10)
channel1 = await channel_layer.new_channel()
channel2 = await channel_layer.new_channel(prefix="thing")
r1, _, r2 = tasks = [
asyncio.ensure_future(x)
for x in (
channel_layer.receive(channel1),
channel_layer.send(channel2, {"type": "message"}),
channel_layer.receive(channel2),
)
]
await asyncio.wait(tasks, timeout=0.5)

assert not r1.done()
assert r2.done() and r2.result()["type"] == "message"
r1.cancel()


@pytest.mark.asyncio
async def test_buffer_wrong_channel(channel_layer):
async def dummy_receive(channel):
return channel, {"type": "message"}

buffer = ReceiveBuffer(dummy_receive, "whatever!")
buffer.loop = asyncio.get_event_loop()
with pytest.raises(AssertionError):
buffer.get("wrong!13685sjmh")


@pytest.mark.asyncio
async def test_buffer_receiver_stopped(channel_layer):
async def dummy_receive(channel):
return "whatever!meh", {"type": "message"}

buffer = ReceiveBuffer(dummy_receive, "whatever!")
buffer.loop = asyncio.get_event_loop()

await buffer.get("whatever!meh")
assert buffer.receiver is None


@pytest.mark.asyncio
async def test_buffer_receiver_canceled(channel_layer):
async def dummy_receive(channel):
await asyncio.sleep(2)
return "whatever!meh", {"type": "message"}

buffer = ReceiveBuffer(dummy_receive, "whatever!")
buffer.loop = asyncio.get_event_loop()

get1 = buffer.get("whatever!meh")
assert buffer.receiver is not None
get2 = buffer.get("whatever!meh2")
get1.cancel()
assert buffer.receiver is not None
get2.cancel()
await asyncio.sleep(0.1)
assert buffer.receiver is None