Add BlockingConnectionPool
Derived from NoneGG/aredis#190
alisaifee committed Jan 17, 2022
1 parent dddb0a8 commit 8070ac0
Showing 7 changed files with 280 additions and 25 deletions.
3 changes: 2 additions & 1 deletion coredis/__init__.py
@@ -30,7 +30,7 @@
TryAgainError,
WatchError,
)
from coredis.pool import ClusterConnectionPool, ConnectionPool
from coredis.pool import BlockingConnectionPool, ClusterConnectionPool, ConnectionPool

from . import _version

@@ -41,6 +41,7 @@
"Connection",
"UnixDomainSocketConnection",
"ClusterConnection",
"BlockingConnectionPool",
"ConnectionPool",
"ClusterConnectionPool",
"AskError",
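
With this re-export in place, the new pool can be imported from the package root alongside the existing pool classes. A minimal sketch; the constructor arguments shown are the ones introduced later in this commit:

    from coredis import BlockingConnectionPool, ConnectionPool

    # Cap the pool at 10 connections and wait up to 5 seconds for a free one.
    blocking_pool = BlockingConnectionPool(max_connections=10, timeout=5)
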
2 changes: 1 addition & 1 deletion coredis/client.py
@@ -216,7 +216,7 @@ async def execute_command(self, *args, **options):
"""Executes a command and returns a parsed response"""
pool = self.connection_pool
command_name = args[0]
connection = pool.get_connection()
connection = await pool.get_connection()
try:
await connection.send_command(*args)

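
get_connection() becomes a coroutine in this commit, so execute_command now awaits it; callers that go through the client API are unaffected, but any code that checks connections out of a pool by hand must await as well. A hedged sketch of that manual pattern, assuming a Redis server on localhost:6379 and the pool's release() method (shown further down in this diff for the blocking pool):

    import asyncio

    from coredis import ConnectionPool


    async def manual_checkout():
        pool = ConnectionPool(max_connections=2)
        connection = await pool.get_connection()  # previously a plain call
        try:
            await connection.send_command("PING")
            # a real caller would read the response here before releasing
        finally:
            pool.release(connection)


    asyncio.run(manual_checkout())
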
4 changes: 2 additions & 2 deletions coredis/pipeline.py
@@ -118,7 +118,7 @@ async def immediate_execute_command(self, *args, **options):
conn = self.connection
# if this is the first call, we need a connection
if not conn:
conn = self.connection_pool.get_connection()
conn = await self.connection_pool.get_connection()
self.connection = conn
try:
await conn.send_command(*args)
@@ -297,7 +297,7 @@ async def execute(self, raise_on_error=True):

conn = self.connection
if not conn:
conn = self.connection_pool.get_connection()
conn = await self.connection_pool.get_connection()
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
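
Both pipeline code paths use the same lazy-acquisition pattern: no connection is checked out until the first command actually needs one, and storing it on self.connection lets reset() hand it back to the pool afterwards. Caller code does not change; a hedged usage sketch, assuming the pipeline is obtained from the client's pipeline() coroutine as in aredis (from which this commit is derived) and a Redis server on localhost:6379:

    import asyncio

    from coredis import StrictRedis


    async def pipeline_roundtrip():
        client = StrictRedis(host="127.0.0.1", port=6379)
        pipe = await client.pipeline()  # assumed factory; not part of this diff
        await pipe.set("greeting", "hello")  # queued, not sent yet
        await pipe.get("greeting")
        # execute() awaits pool.get_connection() internally now
        print(await pipe.execute())


    asyncio.run(pipeline_roundtrip())
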
140 changes: 137 additions & 3 deletions coredis/pool.py
@@ -280,12 +280,14 @@ def _checkpid(self):
self.disconnect()
self.reset()

def get_connection(self, *args, **kwargs):
async def get_connection(self, *args, **kwargs):
"""Gets a connection from the pool"""
self._checkpid()
try:
connection = self._available_connections.pop()
except IndexError:
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
connection = self.make_connection()
self._in_use_connections.add(connection)

@@ -294,8 +296,6 @@ def get_connection(self, *args, **kwargs):
def make_connection(self):
"""Creates a new connection"""

if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
connection = self.connection_class(**self.connection_kwargs)

@@ -329,6 +329,140 @@ def disconnect(self):
self._created_connections -= 1


class BlockingConnectionPool(ConnectionPool):
"""
Blocking connection pool::

>>> from coredis import StrictRedis
>>> client = StrictRedis(connection_pool=BlockingConnectionPool())

It performs the same function as the default
:py:class:`~coredis.ConnectionPool` implementation, in that it
maintains a pool of reusable connections that can be shared by
multiple redis clients.

The difference is that, in the event that a client tries to get a
connection from the pool when all connections are in use, rather than
raising a :py:class:`~coredis.ConnectionError` (as the default
:py:class:`~coredis.ConnectionPool` implementation does), it
makes the client wait ("blocks") for a specified number of seconds until
a connection becomes available.

Use ``max_connections`` to increase / decrease the pool size::

>>> pool = BlockingConnectionPool(max_connections=10)

Use ``timeout`` to tell it either how many seconds to wait for a connection
to become available, or to block forever::

>>> # Block forever.
>>> pool = BlockingConnectionPool(timeout=None)

>>> # Raise a ``ConnectionError`` after five seconds if a connection is
>>> # not available.
>>> pool = BlockingConnectionPool(timeout=5)
"""

def __init__(
self,
connection_class=Connection,
queue_class=asyncio.LifoQueue,
max_connections=None,
timeout=20,
max_idle_time=0,
idle_check_interval=1,
**connection_kwargs
):

self.timeout = timeout
self.queue_class = queue_class

max_connections = max_connections or 50

super(BlockingConnectionPool, self).__init__(
connection_class=connection_class,
max_connections=max_connections,
max_idle_time=max_idle_time,
idle_check_interval=idle_check_interval,
**connection_kwargs
)

async def disconnect_on_idle_time_exceeded(self, connection):
while True:
if (
time.time() - connection.last_active_at > self.max_idle_time
and not connection.awaiting_response
):
# Unlike the non-blocking pool, we don't free the connection object;
# we always reuse it.
connection.disconnect()

break
await asyncio.sleep(self.idle_check_interval)

def reset(self):
self._pool = self.queue_class(self.max_connections)

while True:
try:
self._pool.put_nowait(None)
except asyncio.QueueFull:
break

super(BlockingConnectionPool, self).reset()

async def get_connection(self, *args, **kwargs):
"""Gets a connection from the pool"""
self._checkpid()

connection = None

try:
connection = await asyncio.wait_for(self._pool.get(), self.timeout)
except asyncio.TimeoutError:
raise ConnectionError("No connection available.")

if connection is None:
connection = self.make_connection()

self._in_use_connections.add(connection)

return connection

def release(self, connection):
"""Releases the connection back to the pool"""
self._checkpid()

if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
# discard connection with unread response

if connection.awaiting_response:
connection.disconnect()
connection = None

try:
self._pool.put_nowait(connection)
except asyncio.QueueFull:
# perhaps the pool has been reset()?
pass

def disconnect(self):
"""Closes all connections in the pool"""
pooled_connections = []

while True:
try:
pooled_connections.append(self._pool.get_nowait())
except asyncio.QueueEmpty:
break

for conn in pooled_connections:
try:
self._pool.put_nowait(conn)
except asyncio.QueueFull:
pass

all_conns = chain(pooled_connections, self._in_use_connections)

for connection in all_conns:
if connection is not None:
connection.disconnect()


class ClusterConnectionPool(ConnectionPool):
"""Custom connection pool for rediscluster"""

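
A small end-to-end sketch of the blocking behaviour described in the docstring above: with max_connections=1 the second checkout waits on the internal queue until release() frees the first connection, and would raise ConnectionError("No connection available.") once timeout expires. Only the pool's queueing is exercised here; no command is sent, so (assuming connection objects connect lazily, as elsewhere in this diff) no Redis server is needed:

    import asyncio

    from coredis import BlockingConnectionPool


    async def blocking_demo():
        pool = BlockingConnectionPool(max_connections=1, timeout=2)

        first = await pool.get_connection()  # a slot is free, returns immediately
        waiter = asyncio.create_task(pool.get_connection())  # has to wait

        await asyncio.sleep(0.1)
        assert not waiter.done()  # still blocked: no slot available yet

        pool.release(first)  # puts the connection back on the queue...
        second = await waiter  # ...which unblocks the waiting checkout
        pool.release(second)


    asyncio.run(blocking_demo())

Note the design visible in reset(): the LIFO queue is pre-filled with None sentinels, so a checkout that pops a sentinel creates a real connection lazily via make_connection(), while the queue size always reflects the remaining capacity.
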
28 changes: 19 additions & 9 deletions coredis/pubsub.py
@@ -21,15 +21,18 @@ def __init__(self, connection_pool, ignore_subscribe_messages=False):
self.connection_pool = connection_pool
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection = None
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
conn = connection_pool.get_connection("pubsub")
self.reset()

async def _ensure_encoding(self):
if hasattr(self, "encoding"):
return

conn = await self.connection_pool.get_connection("pubsub")
try:
self.encoding = conn.encoding
self.decode_responses = conn.decode_responses
finally:
connection_pool.release(conn)
self.reset()
self.connection_pool.release(conn)

def __del__(self):
try:
@@ -54,6 +57,7 @@ def close(self):

async def on_connect(self, connection):
"""Re-subscribe to any channels and patterns previously subscribed to"""

if self.channels:
channels = {}

@@ -98,10 +102,10 @@ async def execute_command(self, *args, **kwargs):
# legitimate message off the stack if the connection is already
# subscribed to one or more channels

await self._ensure_encoding()

if self.connection is None:
self.connection = self.connection_pool.get_connection()
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection = await self.connection_pool.get_connection()
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
await self._execute(connection, connection.send_command, *args)
@@ -159,6 +163,7 @@ async def psubscribe(self, *args, **kwargs):
received on that pattern rather than producing a message via
``listen()``.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
@@ -180,6 +185,7 @@ async def punsubscribe(self, *args):
Unsubscribes from the supplied patterns. If empty, unsubscribes from
all patterns.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
@@ -195,6 +201,8 @@ async def subscribe(self, *args, **kwargs):
``get_message()``.
"""

await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
new_channels = {}
@@ -216,6 +224,8 @@ async def unsubscribe(self, *args):
all channels
"""

await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])

@@ -385,7 +395,7 @@ async def execute_command(self, *args, **kwargs):
await self.connection_pool.initialize()

if self.connection is None:
self.connection = self.connection_pool.get_connection(
self.connection = await self.connection_pool.get_connection(
"pubsub",
channel=args[1],
)
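
Because get_connection() is now a coroutine, the encoding/decode_responses lookup can no longer happen in PubSub.__init__; it is deferred to the awaited _ensure_encoding() helper, which every (p)subscribe / (p)unsubscribe call runs first. Callers that only use the pub/sub object through its async methods see no difference. A hedged sketch, where the client.pubsub() factory and the get_message() timeout argument are assumptions carried over from the aredis API this is derived from, and a Redis server on localhost:6379 is assumed:

    import asyncio

    from coredis import StrictRedis


    async def subscribe_once():
        client = StrictRedis(host="127.0.0.1", port=6379)
        pubsub = client.pubsub()
        await pubsub.subscribe("news")  # first awaited call runs _ensure_encoding()
        message = await pubsub.get_message(timeout=1)  # subscription confirmation
        print(message)
        await pubsub.unsubscribe("news")


    asyncio.run(subscribe_once())
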
4 changes: 4 additions & 0 deletions docs/source/api.rst
@@ -21,6 +21,10 @@ Connection Classes
.. autoclass:: Connection
.. autoclass:: UnixDomainSocketConnection
.. autoclass:: ClusterConnection

Connection Pools
^^^^^^^^^^^^^^^^
.. autoclass:: BlockingConnectionPool
.. autoclass:: ConnectionPool
.. autoclass:: ClusterConnectionPool
