diff --git a/aredis/__init__.py b/aredis/__init__.py index 864aec51..d56bc6e0 100644 --- a/aredis/__init__.py +++ b/aredis/__init__.py @@ -4,7 +4,7 @@ UnixDomainSocketConnection, ClusterConnection ) -from aredis.pool import ConnectionPool, ClusterConnectionPool +from aredis.pool import ConnectionPool, ClusterConnectionPool, BlockingConnectionPool from aredis.exceptions import ( AuthenticationError, BusyLoadingError, ConnectionError, DataError, InvalidResponse, PubSubError, ReadOnlyError, @@ -23,7 +23,7 @@ __all__ = [ 'StrictRedis', 'StrictRedisCluster', 'Connection', 'UnixDomainSocketConnection', 'ClusterConnection', - 'ConnectionPool', 'ClusterConnectionPool', + 'ConnectionPool', 'ClusterConnectionPool', 'BlockingConnectionPool', 'AuthenticationError', 'BusyLoadingError', 'ConnectionError', 'DataError', 'InvalidResponse', 'PubSubError', 'ReadOnlyError', 'RedisError', 'ResponseError', 'TimeoutError', 'WatchError', diff --git a/aredis/client.py b/aredis/client.py index 91973db5..91517d67 100644 --- a/aredis/client.py +++ b/aredis/client.py @@ -151,6 +151,8 @@ async def execute_command(self, *args, **options): pool = self.connection_pool command_name = args[0] connection = pool.get_connection() + if asyncio.iscoroutine(connection): + connection = await connection try: await connection.send_command(*args) return await self.parse_response(connection, command_name, **options) diff --git a/aredis/pipeline.py b/aredis/pipeline.py index 90368d5a..38c49003 100644 --- a/aredis/pipeline.py +++ b/aredis/pipeline.py @@ -1,3 +1,4 @@ +import asyncio import inspect import sys from itertools import chain @@ -104,6 +105,8 @@ async def immediate_execute_command(self, *args, **options): # if this is the first call, we need a connection if not conn: conn = self.connection_pool.get_connection() + if asyncio.iscoroutine(conn): + conn = await conn self.connection = conn try: await conn.send_command(*args) @@ -278,6 +281,8 @@ async def execute(self, raise_on_error=True): conn = self.connection if not conn: conn = self.connection_pool.get_connection() + if asyncio.iscoroutine(conn): + conn = await conn # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn diff --git a/aredis/pool.py b/aredis/pool.py index 1fbb2cc7..698d94fe 100644 --- a/aredis/pool.py +++ b/aredis/pool.py @@ -216,14 +216,14 @@ def get_connection(self, *args, **kwargs): try: connection = self._available_connections.pop() except IndexError: + if self._created_connections >= self.max_connections: + raise ConnectionError("Too many connections") connection = self.make_connection() self._in_use_connections.add(connection) return connection def make_connection(self): """Creates a new connection""" - if self._created_connections >= self.max_connections: - raise ConnectionError("Too many connections") self._created_connections += 1 connection = self.connection_class(**self.connection_kwargs) if self.max_idle_time > self.idle_check_interval > 0: @@ -253,6 +253,131 @@ def disconnect(self): self._created_connections -= 1 +class BlockingConnectionPool(ConnectionPool): + """ + Blocking connection pool:: + + >>> from aredis import StrictRedis + >>> client = StrictRedis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + :py:class:`~aredis.ConnectionPool` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple redis clients. + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a :py:class:`~aredis.ConnectionError` (as the default + :py:class:`~aredis.ConnectionPool` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available. + + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever: + + >>> # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + >>> # Raise a ``ConnectionError`` after five seconds if a connection is + >>> # not available. + >>> pool = BlockingConnectionPool(timeout=5) + """ + def __init__(self, connection_class=Connection, queue_class=asyncio.LifoQueue, + max_connections=None, timeout=20, max_idle_time=0, idle_check_interval=1, + **connection_kwargs): + + self.timeout = timeout + self.queue_class = queue_class + + max_connections = max_connections or 50 + + super(BlockingConnectionPool, self).__init__( + connection_class=connection_class, max_connections=max_connections, + max_idle_time=max_idle_time, idle_check_interval=idle_check_interval, + **connection_kwargs) + + async def disconnect_on_idle_time_exceeded(self, connection): + while True: + if (time.time() - connection.last_active_at > self.max_idle_time + and not connection.awaiting_response): + # Unlike the non blocking pool, we don't free the connection object, + # but always reuse it + connection.disconnect() + break + await asyncio.sleep(self.idle_check_interval) + + def reset(self): + self._pool = self.queue_class(self.max_connections) + while True: + try: + self._pool.put_nowait(None) + except asyncio.QueueFull: + break + + super(BlockingConnectionPool, self).reset() + + async def get_connection(self, *args, **kwargs): + """Gets a connection from the pool""" + self._checkpid() + + connection = None + + try: + connection = await asyncio.wait_for( + self._pool.get(), + self.timeout + ) + except asyncio.TimeoutError: + raise ConnectionError("No connection available.") + + if connection is None: + connection = self.make_connection() + + self._in_use_connections.add(connection) + return connection + + def release(self, connection): + """Releases the connection back to the pool""" + self._checkpid() + if connection.pid != self.pid: + return + self._in_use_connections.remove(connection) + # discard connection with unread response + if connection.awaiting_response: + connection.disconnect() + connection = None + + try: + self._pool.put_nowait(connection) + except asyncio.QueueFull: + # perhaps the pool have been reset() ? + pass + + def disconnect(self): + """Closes all connections in the pool""" + pooled_connections = [] + while True: + try: + pooled_connections.append(self._pool.get_nowait()) + except asyncio.QueueEmpty: + break + + for conn in pooled_connections: + try: + self._pool.put_nowait(conn) + except asyncio.QueueFull: + pass + + all_conns = chain(pooled_connections, + self._in_use_connections) + for connection in all_conns: + if connection is not None: + connection.disconnect() + class ClusterConnectionPool(ConnectionPool): """Custom connection pool for rediscluster""" RedisClusterDefaultTimeout = None diff --git a/aredis/pubsub.py b/aredis/pubsub.py index 7dbb9627..3fa479b9 100644 --- a/aredis/pubsub.py +++ b/aredis/pubsub.py @@ -26,13 +26,20 @@ def __init__(self, connection_pool, ignore_subscribe_messages=False): self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - conn = connection_pool.get_connection('pubsub') + self.reset() + + async def _ensure_encoding(self): + if hasattr(self, "encoding"): + return + + conn = self.connection_pool.get_connection('pubsub') + if asyncio.iscoroutine(conn): + conn = await conn try: self.encoding = conn.encoding self.decode_responses = conn.decode_responses finally: - connection_pool.release(conn) - self.reset() + self.connection_pool.release(conn) def __del__(self): try: @@ -93,13 +100,17 @@ def subscribed(self): async def execute_command(self, *args, **kwargs): """Executes a publish/subscribe command""" + await self._ensure_encoding() # NOTE: don't parse the response in this function -- it could pull a # legitimate message off the stack if the connection is already # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection() + conn = self.connection_pool.get_connection() + if asyncio.iscoroutine(conn): + conn = await conn + self.connection = conn # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -151,6 +162,8 @@ async def psubscribe(self, *args, **kwargs): received on that pattern rather than producing a message via ``listen()``. """ + await self._ensure_encoding() + if args: args = list_or_args(args[0], args[1:]) new_patterns = {} @@ -169,6 +182,8 @@ async def punsubscribe(self, *args): Unsubscribes from the supplied patterns. If empy, unsubscribe from all patterns. """ + await self._ensure_encoding() + if args: args = list_or_args(args[0], args[1:]) return await self.execute_command('PUNSUBSCRIBE', *args) @@ -181,6 +196,8 @@ async def subscribe(self, *args, **kwargs): that channel rather than producing a message via ``listen()`` or ``get_message()``. """ + await self._ensure_encoding() + if args: args = list_or_args(args[0], args[1:]) new_channels = {} @@ -199,6 +216,8 @@ async def unsubscribe(self, *args): Unsubscribes from the supplied channels. If empty, unsubscribe from all channels """ + await self._ensure_encoding() + if args: args = list_or_args(args[0], args[1:]) return await self.execute_command('UNSUBSCRIBE', *args) diff --git a/tests/client/test_connection_pool.py b/tests/client/test_connection_pool.py index f07d903c..c9a15857 100644 --- a/tests/client/test_connection_pool.py +++ b/tests/client/test_connection_pool.py @@ -87,6 +87,92 @@ async def test_connection_idle_check(self, event_loop): assert conn._writer is None and conn._reader is None +class TestBlockingConnectionPool: + def get_pool(self, connection_kwargs=None, max_connections=None, + connection_class=DummyConnection, timeout=None): + connection_kwargs = connection_kwargs or {} + pool = aredis.BlockingConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs) + return pool + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_connection_creation(self): + connection_kwargs = {'foo': 'bar', 'biz': 'baz'} + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = await pool.get_connection() + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_multiple_connections(self): + pool = self.get_pool() + c1 = await pool.get_connection() + c2 = await pool.get_connection() + assert c1 != c2 + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_max_connections_timeout(self): + pool = self.get_pool(max_connections=2, timeout=0.1) + await pool.get_connection() + await pool.get_connection() + with pytest.raises(ConnectionError): + await pool.get_connection() + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_max_connections_no_timeout(self): + pool = self.get_pool(max_connections=2) + await pool.get_connection() + released_conn = await pool.get_connection() + def releaser(): + pool.release(released_conn) + + loop = asyncio.get_running_loop() + loop.call_later(0.2, releaser) + new_conn = await pool.get_connection() + assert new_conn == released_conn + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_reuse_previously_released_connection(self): + pool = self.get_pool() + c1 = await pool.get_connection() + pool.release(c1) + c2 = await pool.get_connection() + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=aredis.Connection) + expected = 'BlockingConnectionPool>' + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {'path': '/abc', 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=aredis.UnixDomainSocketConnection) + expected = 'BlockingConnectionPool>' + assert repr(pool) == expected + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_connection_idle_check(self, event_loop): + rs = aredis.StrictRedis(host='127.0.0.1', port=6379, db=0, + max_idle_time=0.2, idle_check_interval=0.1) + await rs.info() + assert len(rs.connection_pool._in_use_connections) == 0 + conn = rs.connection_pool.get_connection() + last_active_at = conn.last_active_at + rs.connection_pool.release(conn) + await asyncio.sleep(0.3) + assert len(rs.connection_pool._in_use_connections) == 0 + assert last_active_at == conn.last_active_at + assert conn._writer is None and conn._reader is None + new_conn = rs.connection_pool.get_connection() + assert conn != new_conn + + class TestConnectionPoolURLParsing: def test_defaults(self): pool = aredis.ConnectionPool.from_url('redis://localhost')