pub: one channel; sub: channel list + reserve-channel
ric-evans committed Dec 1, 2023
1 parent 483485f commit b25345e
Showing 1 changed file with 105 additions and 84 deletions.
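
The change in a nutshell: the publisher now holds exactly one channel, while the subscriber keeps a list of active channels plus a lazily opened "reserve" channel that is tried only after every active channel has come up empty; if the reserve channel does yield a message, it is promoted into the active list. Below is a simplified, self-contained sketch of that rotation, based only on this commit's message and diff -- Channel, get(), open_channel(), and rotate_for_message() are hypothetical stand-ins, not mqclient code (in the real diff the reserve channel persists on the instance as self.reserve_channel).

# Hypothetical sketch only -- not the mqclient implementation.
from typing import Callable, List, Optional


class Channel:
    """Stand-in for one broker channel."""

    def get(self) -> Optional[str]:
        """Return one message, or None if the channel is currently empty."""
        return None


def rotate_for_message(
    active: List[Channel],
    open_channel: Callable[[], Channel],
) -> Optional[str]:
    """Poll each active channel once; fall back to a single reserve channel."""
    if not active:
        raise ValueError("need at least one active channel")
    remaining = list(active)            # working copy for this polling pass
    reserve: Optional[Channel] = None
    channel = remaining[0]
    while True:
        msg = channel.get()
        if msg is not None:
            if channel is reserve:      # the reserve produced a message,
                active.append(channel)  # so promote it into the active list
            return msg
        if channel is reserve:          # even the reserve came up empty:
            return None                 # the queue really is empty
        remaining.remove(channel)
        if remaining:
            channel = remaining[0]      # try the next active channel
        else:
            reserve = open_channel()    # all actives empty: open the reserve
            channel = reserve
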
189 changes: 105 additions & 84 deletions mqclient/broker_clients/rabbitmq.py
@@ -107,11 +107,9 @@ def __init__(

self.queue = queue
self.connection: Optional[pika.BlockingConnection] = None
self.channels: List[pika.adapters.blocking_connection.BlockingChannel] = []
self._next_channel_number = 1 # must start at 1 (not 0)

self._next_channel_number = 1 # must start at 1

def add_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
def open_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
"""Add a channel for the connection and configure."""
LOGGER.info(f"Adding channel to connection for '{self.queue=}'")
if not self.connection:
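
(Aside on the _next_channel_number = 1 bookkeeping above: AMQP reserves channel number 0 for connection-level frames, so the first usable channel number is 1. A hedged one-line illustration with pika follows; whether mqclient passes the number explicitly is not visible in this hunk.)

# Illustration only: pika's BlockingConnection.channel() accepts an explicit number;
# 0 is never valid because AMQP reserves channel 0 for the connection itself.
channel = connection.channel(channel_number=1)
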
@@ -134,26 +132,27 @@ def add_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
queue=self.queue, durable=True, arguments={"x-queue-type": "quorum"}
)

self.channels.append(channel)
LOGGER.info(f"Added channel '{channel.channel_number}': {self.channels}")
# self.channels.append(channel)
LOGGER.info(f"Opened channel '{channel.channel_number}'")
return channel

async def connect(self) -> None:
"""Set up connection and channel."""
async def connect(self) -> pika.adapters.blocking_connection.BlockingChannel:
"""Set up connection and open 1 channel."""
await super().connect()
LOGGER.info(f"Connecting with parameters={self.parameters}")

self.connection = pika.BlockingConnection(self.parameters)
channel = self.add_channel()
if not channel or not self.channels:
channel = self.open_channel()
if not channel:
raise ConnectingFailedException("Channel was not connected")
return channel

async def close(self) -> None:
"""Close connection."""
await super().close()

if not self.channels:
raise ClosingFailedException("No channel to close.")
# if not self.channels:
# raise ClosingFailedException("No channel to close.")
if not self.connection:
raise ClosingFailedException("No connection to close.")
if self.connection.is_closed:
@@ -166,11 +165,6 @@ async def close(self) -> None:
except Exception as e:
raise ClosingFailedException() from e

for channel in self.channels:
if channel.is_open:
LOGGER.warning("Channel remains open after connection close.")
self.channels = []


class RabbitMQPub(RabbitMQ, Pub):
"""Wrapper around queue with delivery-confirm mode in the channel.
@@ -188,13 +182,14 @@ def __init__(
) -> None:
LOGGER.debug(f"{log_msgs.INIT_PUB} ({address}; {name})")
super().__init__(address, name, auth_token)
self.channel: Optional[pika.adapters.blocking_connection.BlockingChannel] = None

def add_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
def open_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
"""Add a channel for the connection and configure."""
if self.channels:
if self.channel:
raise MQClientException("RabbitMQPub instance can only have one channel")

channel = super().add_channel()
channel = super().open_channel()
channel.confirm_delivery()
return channel
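
For context, confirm_delivery() above switches the publisher's channel into RabbitMQ's publisher-confirm mode: pika's BlockingChannel.basic_publish then raises NackError when the broker refuses a message, and UnroutableError when mandatory=True and the message cannot be routed. A minimal standalone sketch -- the "localhost" broker and the "work" queue are illustrative assumptions, not part of this repository:

import pika
from pika.exceptions import NackError, UnroutableError

connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()
channel.queue_declare(queue="work", durable=True)
channel.confirm_delivery()  # basic_publish now raises if the broker cannot confirm

try:
    channel.basic_publish(
        exchange="",
        routing_key="work",
        body=b"hello",
        mandatory=True,  # required for UnroutableError on unroutable messages
    )
except (UnroutableError, NackError):
    ...  # the broker could not route, or refused, the message
finally:
    connection.close()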

@@ -204,7 +199,7 @@ async def connect(self) -> None:
Turn on delivery confirmations.
"""
LOGGER.debug(log_msgs.CONNECTING_PUB)
await super().connect()
self.channel = await super().connect()
LOGGER.debug(log_msgs.CONNECTED_PUB)

async def close(self) -> None:
@@ -229,16 +224,15 @@ async def send_message(
RawQueue: queue
"""
LOGGER.debug(log_msgs.SENDING_MESSAGE)
if not self.channels:
if not self.channel:
raise MQClientException("queue is not connected")

def _send_msg():
# use wrapper function so connection references can be updated by reconnects
if not self.channels:
if not self.channel:
raise MQClientException("queue is not connected")
channel = self.channels[0]
LOGGER.debug(f"sending on channel: {channel.channel_number}")
return channel.basic_publish(
LOGGER.debug(f"sending on channel: {self.channel.channel_number}")
return self.channel.basic_publish(
exchange="",
routing_key=self.queue,
body=msg,
@@ -282,12 +276,19 @@ def __init__(
self.consumer_id = None
self.prefetch = prefetch

def add_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
self.active_channels: List[
pika.adapters.blocking_connection.BlockingChannel
] = []
self.reserve_channel: Optional[
pika.adapters.blocking_connection.BlockingChannel
] = None

def open_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:
"""Add a channel for the connection and configure.
Turn on prefetching.
"""
channel = super().add_channel()
channel = super().open_channel()

channel.basic_qos(prefetch_count=max(self.prefetch, 1))
# Setting prefetch_count to 0 would let the consumer drain the entire queue, so clamp it to at least 1.
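
For context on the prefetch clamp above, here is a hedged, standalone consuming sketch with pika's BlockingChannel, reusing the hypothetical "work" queue from the publishing sketch earlier; the 1-second inactivity timeout is likewise illustrative. consume() yields (None, None, None) when the timeout expires with nothing waiting, which appears to be the same "empty channel" signal handled in _get_msg() further down.

# Standalone sketch; `channel` is assumed to be an open BlockingChannel and
# "work" an existing queue -- both illustrative, not mqclient specifics.
channel.basic_qos(prefetch_count=1)  # 0 would mean "no limit"
for method, properties, body in channel.consume("work", inactivity_timeout=1.0):
    if method is None:
        break  # timed out with nothing waiting
    print(body)
    channel.basic_ack(delivery_tag=method.delivery_tag)
channel.cancel()  # stop the consumer started by consume()
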
@@ -303,22 +304,11 @@ def add_channel(self) -> pika.adapters.blocking_connection.BlockingChannel:

return channel

def remove_channel(
self,
channel: pika.adapters.blocking_connection.BlockingChannel,
) -> None:
"""Remove the channel from the channel list, and close.
If there are any ack-pending messages, those will fail.
"""
self.channels.remove(channel)
if channel.is_open:
channel.close()

async def connect(self) -> None:
"""Set up connection, channel, and queue."""
LOGGER.debug(log_msgs.CONNECTING_SUB)
await super().connect()
channel = await super().connect()
self.active_channels.append(channel)
LOGGER.debug(log_msgs.CONNECTED_SUB)

async def close(self) -> None:
@@ -329,6 +319,13 @@ async def close(self) -> None:
"""
LOGGER.debug(log_msgs.CLOSING_SUB)
await super().close()
for channel in self.active_channels + [self.reserve_channel]:
if channel and channel.is_open:  # reserve_channel may be None
LOGGER.warning(
f"Channel remains open after connection close: {channel.channel_number}."
)
self.active_channels = []
self.reserve_channel = None
LOGGER.debug(log_msgs.CLOSED_SUB)

@staticmethod
@@ -355,12 +352,12 @@ async def _iter_messages(
retries: int,
retry_delay: float,
) -> AsyncIterator[Optional[Message]]:
if not self.channels:
if not self.active_channels:
raise MQClientException("queue is not connected")

def _get_msg():
# use wrapper function so connection references can be updated by reconnects
if not self.channels:
if not self.active_channels:
raise MQClientException("queue is not connected")
LOGGER.debug(f"consuming on channel: {channel.channel_number}")
try:
@@ -376,18 +373,20 @@ def _get_msg():
except StopIteration:
return (None, None, None)

def infinite_loop_over_channels() -> (
Iterator[pika.adapters.blocking_connection.BlockingChannel]
):
# this allows self.channels to be updated,
# updates are reflected on outer-loop
# itertools.cycle() does not allow updates
while True:
yield from self.channels

inf_channels_gen = infinite_loop_over_channels()
channel = next(inf_channels_gen) # always called manually
n_nonempty_channels_remaining = len(self.channels) # assume all are non-empty
# def infinite_loop_over_channels() -> (
# Iterator[pika.adapters.blocking_connection.BlockingChannel]
# ):
# # this allows self.channels to be updated,
# # updates are reflected on outer-loop
# # itertools.cycle() does not allow updates
# while True:
# yield from self.channels

# inf_channels_gen = infinite_loop_over_channels()
# channel = next(inf_channels_gen) # always called manually
# n_nonempty_channels_remaining = len(self.channels) # assume all are non-empty
remaining_channels = list(self.active_channels)  # start with a copy of all active channels
channel = remaining_channels[0] # TODO - use by priority?

while True:
try:
@@ -416,37 +415,57 @@ def infinite_loop_over_channels() -> (
channel.channel_number,
):
LOGGER.debug(f"{log_msgs.GETMSG_RECEIVED_MESSAGE} ({msg.msg_id!r}).")
n_nonempty_channels_remaining = len(self.channels) # reset!
# n_nonempty_channels_remaining = len(self.channels) # reset!
if channel == self.reserve_channel:
# if this was the reserve channel, move it to active channels
self.active_channels.append(self.reserve_channel)
self.reserve_channel = None # no need to open a new one now
remaining_channels = list(self.active_channels)  # reset (fresh copy)!
yield msg
# DEAL WITH EMPTY CHANNEL (didn't get a message)
else:
n_nonempty_channels_remaining -= 1
# n_nonempty_channels_remaining -= 1
LOGGER.debug("No message received -- switching channels...")
if n_nonempty_channels_remaining == 0:
# don't reset n_nonempty_channels_remaining so we can see if this one is empty
channel = self.add_channel() # try it now
# this new channel will be yielded by inf_channels_gen eventually

# FIXME - this leads to MANY empty channels if the user
# keeps cycling despite coming up empty
# TODO - figure way to close or limit channels, but
# consider that some ack-pending messages need to use
# the channel to ack. Keep one local mapping?
# Currently uses msg._connection_id.
# Maybe just closing a newly-opened channel will suffice?

continue
elif n_nonempty_channels_remaining < 0: # -1
# this means our newly produced channel is empty,
if channel == self.reserve_channel:
# this means our reserve channel came up empty,
# so there's REALLY nothing in the queue
LOGGER.debug(log_msgs.GETMSG_NO_MESSAGE)
self.remove_channel(channel)
channel = next(inf_channels_gen) # try next, next time
# FIXME - this design still allows a lot of new channels, one per final iter...
remaining_channels = list(self.active_channels)  # reset (fresh copy)!
yield None
else:
channel = next(inf_channels_gen) # try next, now
continue
remaining_channels.remove(channel)
# if n_nonempty_channels_remaining == 0:
if not remaining_channels:
# don't reset n_nonempty_channels_remaining so we can see if this one is empty
# channel = self.open_channel() # try it now
# this new channel will be yielded by inf_channels_gen eventually

# FIXME - this leads to MANY empty channels if the user
# keeps cycling despite coming up empty
# TODO - figure way to close or limit channels, but
# consider that some ack-pending messages need to use
# the channel to ack. Keep one local mapping?
# Currently uses msg._connection_id.
# Maybe just closing a newly-opened channel will suffice?

# try reserve channel
if not self.reserve_channel:
self.reserve_channel = self.open_channel()
channel = self.reserve_channel

# continue
# elif n_nonempty_channels_remaining < 0: # -1
# # this means our newly produced channel is empty,
# # so there's REALLY nothing in the queue
# LOGGER.debug(log_msgs.GETMSG_NO_MESSAGE)
# self.reserve_channel
# channel = next(inf_channels_gen) # try next, next time
# # FIXME - this design still allows a lot of new channels, one per final iter...
# yield None
else:
# channel = next(inf_channels_gen) # try next, now
channel = remaining_channels[0] # TODO - use by priority?
# continue

async def get_message(
self,
@@ -456,7 +475,7 @@ async def get_message(
) -> Optional[Message]:
"""Get a message from a queue."""
LOGGER.debug(log_msgs.GETMSG_RECEIVE_MESSAGE)
if not self.channels:
if not self.active_channels:
raise MQClientException("queue is not connected")

msg = None
@@ -472,10 +491,12 @@ def _get_channel_by_msg(
self, msg: Message
) -> pika.adapters.blocking_connection.BlockingChannel:
"""Map message to channel."""
matches = [c for c in self.channels if c.channel_number == msg._connection_id]
matches = [
c for c in self.active_channels if c.channel_number == msg._connection_id
]
if not matches:
raise MQClientException(
f"could not map message to channel: {msg} {self.channels}"
f"could not map message to channel: {msg} {self.active_channels}"
)
elif len(matches) > 1:
raise MQClientException(
@@ -492,7 +513,7 @@ async def ack_message(
) -> None:
"""Ack a message from the queue."""
LOGGER.debug(log_msgs.ACKING_MESSAGE)
if not self.channels:
if not self.active_channels:
raise MQClientException("queue is not connected")

channel = self._get_channel_by_msg(msg)
@@ -527,7 +548,7 @@ async def reject_message(
) -> None:
"""Reject (nack) a message from the queue."""
LOGGER.debug(log_msgs.NACKING_MESSAGE)
if not self.channels:
if not self.active_channels:
raise MQClientException("queue is not connected")

channel = self._get_channel_by_msg(msg)
@@ -572,7 +593,7 @@ async def message_generator(
propagate_error -- should errors from downstream kill the generator? (default: {True})
"""
LOGGER.debug(log_msgs.MSGGEN_ENTERED)
if not self.channels:
if not self.active_channels:
raise MQClientException("queue is not connected")

msg = None
