Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BlockingConnectionPool #190

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aredis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
UnixDomainSocketConnection,
ClusterConnection
)
from aredis.pool import ConnectionPool, ClusterConnectionPool
from aredis.pool import ConnectionPool, ClusterConnectionPool, BlockingConnectionPool
from aredis.exceptions import (
AuthenticationError, BusyLoadingError, ConnectionError,
DataError, InvalidResponse, PubSubError, ReadOnlyError,
Expand All @@ -23,7 +23,7 @@
__all__ = [
'StrictRedis', 'StrictRedisCluster',
'Connection', 'UnixDomainSocketConnection', 'ClusterConnection',
'ConnectionPool', 'ClusterConnectionPool',
'ConnectionPool', 'ClusterConnectionPool', 'BlockingConnectionPool',
'AuthenticationError', 'BusyLoadingError', 'ConnectionError', 'DataError',
'InvalidResponse', 'PubSubError', 'ReadOnlyError', 'RedisError',
'ResponseError', 'TimeoutError', 'WatchError',
Expand Down
2 changes: 2 additions & 0 deletions aredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ async def execute_command(self, *args, **options):
pool = self.connection_pool
command_name = args[0]
connection = pool.get_connection()
if asyncio.iscoroutine(connection):
connection = await connection
try:
await connection.send_command(*args)
return await self.parse_response(connection, command_name, **options)
Expand Down
5 changes: 5 additions & 0 deletions aredis/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import sys
from itertools import chain
Expand Down Expand Up @@ -104,6 +105,8 @@ async def immediate_execute_command(self, *args, **options):
# if this is the first call, we need a connection
if not conn:
conn = self.connection_pool.get_connection()
if asyncio.iscoroutine(conn):
conn = await conn
self.connection = conn
try:
await conn.send_command(*args)
Expand Down Expand Up @@ -278,6 +281,8 @@ async def execute(self, raise_on_error=True):
conn = self.connection
if not conn:
conn = self.connection_pool.get_connection()
if asyncio.iscoroutine(conn):
conn = await conn
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
Expand Down
129 changes: 127 additions & 2 deletions aredis/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,14 @@ def get_connection(self, *args, **kwargs):
try:
connection = self._available_connections.pop()
except IndexError:
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
connection = self.make_connection()
self._in_use_connections.add(connection)
return connection

def make_connection(self):
"""Creates a new connection"""
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
connection = self.connection_class(**self.connection_kwargs)
if self.max_idle_time > self.idle_check_interval > 0:
Expand Down Expand Up @@ -253,6 +253,131 @@ def disconnect(self):
self._created_connections -= 1


class BlockingConnectionPool(ConnectionPool):
"""
Blocking connection pool::

>>> from aredis import StrictRedis
>>> client = StrictRedis(connection_pool=BlockingConnectionPool())

It performs the same function as the default
:py:class:`~aredis.ConnectionPool` implementation, in that,
it maintains a pool of reusable connections that can be shared by
multiple redis clients.

The difference is that, in the event that a client tries to get a
connection from the pool when all of connections are in use, rather than
raising a :py:class:`~aredis.ConnectionError` (as the default
:py:class:`~aredis.ConnectionPool` implementation does), it
makes the client wait ("blocks") for a specified number of seconds until
a connection becomes available.

Use ``max_connections`` to increase / decrease the pool size::

>>> pool = BlockingConnectionPool(max_connections=10)

Use ``timeout`` to tell it either how many seconds to wait for a connection
to become available, or to block forever:

>>> # Block forever.
>>> pool = BlockingConnectionPool(timeout=None)

>>> # Raise a ``ConnectionError`` after five seconds if a connection is
>>> # not available.
>>> pool = BlockingConnectionPool(timeout=5)
"""
def __init__(self, connection_class=Connection, queue_class=asyncio.LifoQueue,
max_connections=None, timeout=20, max_idle_time=0, idle_check_interval=1,
**connection_kwargs):

self.timeout = timeout
self.queue_class = queue_class

max_connections = max_connections or 50

super(BlockingConnectionPool, self).__init__(
connection_class=connection_class, max_connections=max_connections,
max_idle_time=max_idle_time, idle_check_interval=idle_check_interval,
**connection_kwargs)

async def disconnect_on_idle_time_exceeded(self, connection):
while True:
if (time.time() - connection.last_active_at > self.max_idle_time
and not connection.awaiting_response):
# Unlike the non blocking pool, we don't free the connection object,
# but always reuse it
connection.disconnect()
break
await asyncio.sleep(self.idle_check_interval)

def reset(self):
self._pool = self.queue_class(self.max_connections)
while True:
try:
self._pool.put_nowait(None)
except asyncio.QueueFull:
break

super(BlockingConnectionPool, self).reset()

async def get_connection(self, *args, **kwargs):
"""Gets a connection from the pool"""
self._checkpid()

connection = None

try:
connection = await asyncio.wait_for(
self._pool.get(),
self.timeout
)
except asyncio.TimeoutError:
raise ConnectionError("No connection available.")

if connection is None:
connection = self.make_connection()

self._in_use_connections.add(connection)
return connection

def release(self, connection):
"""Releases the connection back to the pool"""
self._checkpid()
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
# discard connection with unread response
if connection.awaiting_response:
connection.disconnect()
connection = None

try:
self._pool.put_nowait(connection)
except asyncio.QueueFull:
# perhaps the pool have been reset() ?
pass

def disconnect(self):
"""Closes all connections in the pool"""
pooled_connections = []
while True:
try:
pooled_connections.append(self._pool.get_nowait())
except asyncio.QueueEmpty:
break

for conn in pooled_connections:
try:
self._pool.put_nowait(conn)
except asyncio.QueueFull:
pass

all_conns = chain(pooled_connections,
self._in_use_connections)
for connection in all_conns:
if connection is not None:
connection.disconnect()

class ClusterConnectionPool(ConnectionPool):
"""Custom connection pool for rediscluster"""
RedisClusterDefaultTimeout = None
Expand Down
27 changes: 23 additions & 4 deletions aredis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ def __init__(self, connection_pool, ignore_subscribe_messages=False):
self.connection = None
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
conn = connection_pool.get_connection('pubsub')
self.reset()

async def _ensure_encoding(self):
if hasattr(self, "encoding"):
return

conn = self.connection_pool.get_connection('pubsub')
if asyncio.iscoroutine(conn):
conn = await conn
try:
self.encoding = conn.encoding
self.decode_responses = conn.decode_responses
finally:
connection_pool.release(conn)
self.reset()
self.connection_pool.release(conn)

def __del__(self):
try:
Expand Down Expand Up @@ -93,13 +100,17 @@ def subscribed(self):

async def execute_command(self, *args, **kwargs):
"""Executes a publish/subscribe command"""
await self._ensure_encoding()

# NOTE: don't parse the response in this function -- it could pull a
# legitimate message off the stack if the connection is already
# subscribed to one or more channels

if self.connection is None:
self.connection = self.connection_pool.get_connection()
conn = self.connection_pool.get_connection()
if asyncio.iscoroutine(conn):
conn = await conn
self.connection = conn
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
Expand Down Expand Up @@ -151,6 +162,8 @@ async def psubscribe(self, *args, **kwargs):
received on that pattern rather than producing a message via
``listen()``.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
new_patterns = {}
Expand All @@ -169,6 +182,8 @@ async def punsubscribe(self, *args):
Unsubscribes from the supplied patterns. If empy, unsubscribe from
all patterns.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
return await self.execute_command('PUNSUBSCRIBE', *args)
Expand All @@ -181,6 +196,8 @@ async def subscribe(self, *args, **kwargs):
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
new_channels = {}
Expand All @@ -199,6 +216,8 @@ async def unsubscribe(self, *args):
Unsubscribes from the supplied channels. If empty, unsubscribe from
all channels
"""
await self._ensure_encoding()

if args:
args = list_or_args(args[0], args[1:])
return await self.execute_command('UNSUBSCRIBE', *args)
Expand Down
86 changes: 86 additions & 0 deletions tests/client/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,92 @@ async def test_connection_idle_check(self, event_loop):
assert conn._writer is None and conn._reader is None


class TestBlockingConnectionPool:
def get_pool(self, connection_kwargs=None, max_connections=None,
connection_class=DummyConnection, timeout=None):
connection_kwargs = connection_kwargs or {}
pool = aredis.BlockingConnectionPool(
connection_class=connection_class,
max_connections=max_connections,
timeout=timeout,
**connection_kwargs)
return pool

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
pool = self.get_pool(connection_kwargs=connection_kwargs)
connection = await pool.get_connection()
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_multiple_connections(self):
pool = self.get_pool()
c1 = await pool.get_connection()
c2 = await pool.get_connection()
assert c1 != c2

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_max_connections_timeout(self):
pool = self.get_pool(max_connections=2, timeout=0.1)
await pool.get_connection()
await pool.get_connection()
with pytest.raises(ConnectionError):
await pool.get_connection()

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_max_connections_no_timeout(self):
pool = self.get_pool(max_connections=2)
await pool.get_connection()
released_conn = await pool.get_connection()
def releaser():
pool.release(released_conn)

loop = asyncio.get_running_loop()
loop.call_later(0.2, releaser)
new_conn = await pool.get_connection()
assert new_conn == released_conn

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_reuse_previously_released_connection(self):
pool = self.get_pool()
c1 = await pool.get_connection()
pool.release(c1)
c2 = await pool.get_connection()
assert c1 == c2

def test_repr_contains_db_info_tcp(self):
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=aredis.Connection)
expected = 'BlockingConnectionPool<Connection<host=localhost,port=6379,db=1>>'
assert repr(pool) == expected

def test_repr_contains_db_info_unix(self):
connection_kwargs = {'path': '/abc', 'db': 1}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=aredis.UnixDomainSocketConnection)
expected = 'BlockingConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected

@pytest.mark.asyncio(forbid_global_loop=True)
async def test_connection_idle_check(self, event_loop):
rs = aredis.StrictRedis(host='127.0.0.1', port=6379, db=0,
max_idle_time=0.2, idle_check_interval=0.1)
await rs.info()
assert len(rs.connection_pool._in_use_connections) == 0
conn = rs.connection_pool.get_connection()
last_active_at = conn.last_active_at
rs.connection_pool.release(conn)
await asyncio.sleep(0.3)
assert len(rs.connection_pool._in_use_connections) == 0
assert last_active_at == conn.last_active_at
assert conn._writer is None and conn._reader is None
new_conn = rs.connection_pool.get_connection()
assert conn != new_conn


class TestConnectionPoolURLParsing:
def test_defaults(self):
pool = aredis.ConnectionPool.from_url('redis://localhost')
Expand Down