feat: add BlockingConnectionPool
sileht committed Mar 25, 2021
1 parent b46e671 commit 925b4a2
Showing 6 changed files with 245 additions and 8 deletions.
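
The new pool is a drop-in alternative to the default ConnectionPool. A minimal usage sketch (host, port, and key name are illustrative and not part of this commit):

import asyncio
from aredis import StrictRedis, BlockingConnectionPool

async def main():
    # Wait up to 5 seconds for a free connection instead of raising
    # ConnectionError as soon as the pool is exhausted.
    pool = BlockingConnectionPool(host='127.0.0.1', port=6379,
                                  max_connections=10, timeout=5)
    client = StrictRedis(connection_pool=pool)
    await client.set('greeting', 'hello')
    print(await client.get('greeting'))

asyncio.run(main())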
4 changes: 2 additions & 2 deletions aredis/__init__.py
@@ -4,7 +4,7 @@
UnixDomainSocketConnection,
ClusterConnection
)
from aredis.pool import ConnectionPool, ClusterConnectionPool
from aredis.pool import ConnectionPool, ClusterConnectionPool, BlockingConnectionPool
from aredis.exceptions import (
AuthenticationError, BusyLoadingError, ConnectionError,
DataError, InvalidResponse, PubSubError, ReadOnlyError,
@@ -23,7 +23,7 @@
__all__ = [
'StrictRedis', 'StrictRedisCluster',
'Connection', 'UnixDomainSocketConnection', 'ClusterConnection',
'ConnectionPool', 'ClusterConnectionPool',
'ConnectionPool', 'ClusterConnectionPool', 'BlockingConnectionPool',
'AuthenticationError', 'BusyLoadingError', 'ConnectionError', 'DataError',
'InvalidResponse', 'PubSubError', 'ReadOnlyError', 'RedisError',
'ResponseError', 'TimeoutError', 'WatchError',
2 changes: 2 additions & 0 deletions aredis/client.py
@@ -151,6 +151,8 @@ async def execute_command(self, *args, **options):
pool = self.connection_pool
command_name = args[0]
connection = pool.get_connection()
if asyncio.iscoroutine(connection):
connection = await connection
try:
await connection.send_command(*args)
return await self.parse_response(connection, command_name, **options)
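
The iscoroutine check above lets one code path serve both pool types: ConnectionPool.get_connection() returns a connection directly, while BlockingConnectionPool.get_connection() is a coroutine that must be awaited. The same pattern, pulled out into a standalone helper (the helper name is illustrative, not part of this commit):

import asyncio

async def acquire_connection(pool):
    # Works with the synchronous ConnectionPool.get_connection() as well
    # as the awaitable BlockingConnectionPool.get_connection().
    connection = pool.get_connection()
    if asyncio.iscoroutine(connection):
        connection = await connection
    return connection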
5 changes: 5 additions & 0 deletions aredis/pipeline.py
@@ -1,3 +1,4 @@
import asyncio
import inspect
import sys
from itertools import chain
@@ -104,6 +105,8 @@ async def immediate_execute_command(self, *args, **options):
# if this is the first call, we need a connection
if not conn:
conn = self.connection_pool.get_connection()
if asyncio.iscoroutine(conn):
conn = await conn
self.connection = conn
try:
await conn.send_command(*args)
@@ -278,6 +281,8 @@ async def execute(self, raise_on_error=True):
conn = self.connection
if not conn:
conn = self.connection_pool.get_connection()
if asyncio.iscoroutine(conn):
conn = await conn
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
129 changes: 127 additions & 2 deletions aredis/pool.py
@@ -216,14 +216,14 @@ def get_connection(self, *args, **kwargs):
try:
connection = self._available_connections.pop()
except IndexError:
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
connection = self.make_connection()
self._in_use_connections.add(connection)
return connection

def make_connection(self):
"""Creates a new connection"""
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
connection = self.connection_class(**self.connection_kwargs)
if self.max_idle_time > self.idle_check_interval > 0:
@@ -253,6 +253,131 @@ def disconnect(self):
self._created_connections -= 1


class BlockingConnectionPool(ConnectionPool):
"""
Blocking connection pool::
>>> from aredis import StrictRedis, BlockingConnectionPool
>>> client = StrictRedis(connection_pool=BlockingConnectionPool())
It performs the same function as the default
:py:class:`~aredis.ConnectionPool` implementation, in that it
maintains a pool of reusable connections that can be shared by
multiple redis clients.
The difference is that, when a client tries to get a connection
from the pool while all connections are in use, rather than
raising a :py:class:`~aredis.ConnectionError` (as the default
:py:class:`~aredis.ConnectionPool` implementation does), it
makes the client wait ("block") for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
>>> pool = BlockingConnectionPool(max_connections=10)
Use ``timeout`` to tell it either how many seconds to wait for a connection
to become available, or to block forever:
>>> # Block forever.
>>> pool = BlockingConnectionPool(timeout=None)
>>> # Raise a ``ConnectionError`` after five seconds if a connection is
>>> # not available.
>>> pool = BlockingConnectionPool(timeout=5)
"""
def __init__(self, connection_class=Connection, queue_class=asyncio.LifoQueue,
max_connections=None, timeout=20, max_idle_time=0, idle_check_interval=1,
**connection_kwargs):

self.timeout = timeout
self.queue_class = queue_class

max_connections = max_connections or 50

super(BlockingConnectionPool, self).__init__(
connection_class=connection_class, max_connections=max_connections,
max_idle_time=max_idle_time, idle_check_interval=idle_check_interval,
**connection_kwargs)

async def disconnect_on_idle_time_exceeded(self, connection):
while True:
if (time.time() - connection.last_active_at > self.max_idle_time
and not connection.awaiting_response):
# Unlike the non-blocking pool, we don't free the connection object,
# but always reuse it
connection.disconnect()
break
await asyncio.sleep(self.idle_check_interval)

def reset(self):
self._pool = self.queue_class(self.max_connections)
while True:
try:
self._pool.put_nowait(None)
except asyncio.QueueFull:
break

super(BlockingConnectionPool, self).reset()

async def get_connection(self, *args, **kwargs):
"""Gets a connection from the pool"""
self._checkpid()

connection = None

try:
connection = await asyncio.wait_for(
self._pool.get(),
self.timeout
)
except asyncio.TimeoutError:
raise ConnectionError("No connection available.")

if connection is None:
connection = self.make_connection()

self._in_use_connections.add(connection)
return connection

def release(self, connection):
"""Releases the connection back to the pool"""
self._checkpid()
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
# discard connection with unread response
if connection.awaiting_response:
connection.disconnect()
connection = None

try:
self._pool.put_nowait(connection)
except asyncio.QueueFull:
# perhaps the pool has been reset()?
pass

def disconnect(self):
"""Closes all connections in the pool"""
pooled_connections = []
while True:
try:
pooled_connections.append(self._pool.get_nowait())
except asyncio.QueueEmpty:
break

for conn in pooled_connections:
try:
self._pool.put_nowait(conn)
except asyncio.QueueFull:
pass

all_conns = chain(pooled_connections,
self._in_use_connections)
for connection in all_conns:
if connection is not None:
connection.disconnect()

class ClusterConnectionPool(ConnectionPool):
"""Custom connection pool for rediscluster"""
RedisClusterDefaultTimeout = None
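
The BlockingConnectionPool added above can also be driven directly. A small sketch of the get/release cycle, mirroring the new tests below; it assumes no Redis server is needed because connections only connect on first use (an assumption of this sketch, not something the commit states):

import asyncio
from aredis import BlockingConnectionPool

async def demo():
    # A two-slot pool: a third get_connection() waits until a connection
    # is released, or raises ConnectionError after `timeout` seconds.
    pool = BlockingConnectionPool(max_connections=2, timeout=0.5)
    c1 = await pool.get_connection()
    c2 = await pool.get_connection()
    pool.release(c1)
    c3 = await pool.get_connection()
    assert c3 is c1  # the released connection is handed out again

asyncio.run(demo())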
27 changes: 23 additions & 4 deletions aredis/pubsub.py
@@ -26,13 +26,20 @@ def __init__(self, connection_pool, ignore_subscribe_messages=False):
self.connection = None
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
conn = connection_pool.get_connection('pubsub')
self.reset()

async def _ensure_encoding(self):
if hasattr(self, "encoding"):
return

conn = self.connection_pool.get_connection('pubsub')
if asyncio.iscoroutine(conn):
conn = await conn
try:
self.encoding = conn.encoding
self.decode_responses = conn.decode_responses
finally:
connection_pool.release(conn)
self.reset()
self.connection_pool.release(conn)

def __del__(self):
try:
@@ -93,13 +100,17 @@ def subscribed(self):

async def execute_command(self, *args, **kwargs):
"""Executes a publish/subscribe command"""
await self._ensure_encoding()

# NOTE: don't parse the response in this function -- it could pull a
# legitimate message off the stack if the connection is already
# subscribed to one or more channels

if self.connection is None:
self.connection = self.connection_pool.get_connection()
conn = self.connection_pool.get_connection()
if asyncio.iscoroutine(conn):
conn = await conn
self.connection = conn
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
@@ -151,6 +162,8 @@ async def psubscribe(self, *args, **kwargs):
received on that pattern rather than producing a message via
``listen()``.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
new_patterns = {}
@@ -169,6 +182,8 @@ async def punsubscribe(self, *args):
Unsubscribes from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
return await self.execute_command('PUNSUBSCRIBE', *args)
@@ -181,6 +196,8 @@ async def subscribe(self, *args, **kwargs):
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
new_channels = {}
@@ -199,6 +216,8 @@ async def unsubscribe(self, *args):
Unsubscribes from the supplied channels. If empty, unsubscribe from
all channels.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
return await self.execute_command('UNSUBSCRIBE', *args)
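
Because _ensure_encoding() is awaited inside each pub/sub command, nothing changes for callers of the public API. A typical consumer loop still looks roughly like this (channel name and polling interval are illustrative; client.pubsub() and get_message() are assumed to behave as in the existing redis-py-style API):

import asyncio
from aredis import StrictRedis, BlockingConnectionPool

async def reader():
    client = StrictRedis(connection_pool=BlockingConnectionPool())
    pubsub = client.pubsub(ignore_subscribe_messages=True)
    await pubsub.subscribe('notifications')  # first command triggers _ensure_encoding()
    while True:
        message = await pubsub.get_message()
        if message is not None:
            print(message['channel'], message['data'])
        await asyncio.sleep(0.01)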
86 changes: 86 additions & 0 deletions tests/client/test_connection_pool.py
@@ -87,6 +87,92 @@ async def test_connection_idle_check(self, event_loop):
assert conn._writer is None and conn._reader is None


class TestBlockingConnectionPool:
def get_pool(self, connection_kwargs=None, max_connections=None,
connection_class=DummyConnection, timeout=None):
connection_kwargs = connection_kwargs or {}
pool = aredis.BlockingConnectionPool(
connection_class=connection_class,
max_connections=max_connections,
timeout=timeout,
**connection_kwargs)
return pool

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = await pool.get_connection()
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_multiple_connections(self):
pool = self.get_pool()
c1 = await pool.get_connection()
c2 = await pool.get_connection()
assert c1 != c2

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_max_connections_timeout(self):
pool = self.get_pool(max_connections=2, timeout=0.1)
await pool.get_connection()
await pool.get_connection()
with pytest.raises(ConnectionError):
await pool.get_connection()

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_max_connections_no_timeout(self):
pool = self.get_pool(max_connections=2)
await pool.get_connection()
released_conn = await pool.get_connection()
def releaser():
pool.release(released_conn)

loop = asyncio.get_running_loop()
loop.call_later(0.2, releaser)
new_conn = await pool.get_connection()
assert new_conn == released_conn

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = await pool.get_connection()
pool.release(c1)
c2 = await pool.get_connection()
assert c1 == c2

def test_repr_contains_db_info_tcp(self):
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=aredis.Connection)
expected = 'BlockingConnectionPool<Connection<host=localhost,port=6379,db=1>>'
assert repr(pool) == expected

def test_repr_contains_db_info_unix(self):
connection_kwargs = {'path': '/abc', 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=aredis.UnixDomainSocketConnection)
expected = 'BlockingConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_connection_idle_check(self, event_loop):
rs = aredis.StrictRedis(host='127.0.0.1', port=6379, db=0,
max_idle_time=0.2, idle_check_interval=0.1)
await rs.info()
assert len(rs.connection_pool._in_use_connections) == 0
conn = rs.connection_pool.get_connection()
last_active_at = conn.last_active_at
rs.connection_pool.release(conn)
await asyncio.sleep(0.3)
assert len(rs.connection_pool._in_use_connections) == 0
assert last_active_at == conn.last_active_at
assert conn._writer is None and conn._reader is None
new_conn = rs.connection_pool.get_connection()
assert conn != new_conn


class TestConnectionPoolURLParsing:
def test_defaults(self):
pool = aredis.ConnectionPool.from_url('redis://localhost')
