diff --git a/coredis/__init__.py b/coredis/__init__.py
index ee5bec03..3e85c93a 100644
--- a/coredis/__init__.py
+++ b/coredis/__init__.py
@@ -30,7 +30,7 @@
     TryAgainError,
     WatchError,
 )
-from coredis.pool import ClusterConnectionPool, ConnectionPool
+from coredis.pool import BlockingConnectionPool, ClusterConnectionPool, ConnectionPool
 
 from . import _version
 
@@ -41,6 +41,7 @@
     "Connection",
     "UnixDomainSocketConnection",
     "ClusterConnection",
+    "BlockingConnectionPool",
     "ConnectionPool",
     "ClusterConnectionPool",
     "AskError",
diff --git a/coredis/client.py b/coredis/client.py
index 5b35caaa..acc3ae0a 100644
--- a/coredis/client.py
+++ b/coredis/client.py
@@ -216,7 +216,7 @@ async def execute_command(self, *args, **options):
         """Executes a command and returns a parsed response"""
         pool = self.connection_pool
         command_name = args[0]
-        connection = pool.get_connection()
+        connection = await pool.get_connection()
         try:
             await connection.send_command(*args)
diff --git a/coredis/pipeline.py b/coredis/pipeline.py
index 400e6e48..3067de87 100644
--- a/coredis/pipeline.py
+++ b/coredis/pipeline.py
@@ -118,7 +118,7 @@ async def immediate_execute_command(self, *args, **options):
         conn = self.connection
         # if this is the first call, we need a connection
         if not conn:
-            conn = self.connection_pool.get_connection()
+            conn = await self.connection_pool.get_connection()
             self.connection = conn
         try:
             await conn.send_command(*args)
@@ -297,7 +297,7 @@ async def execute(self, raise_on_error=True):
         conn = self.connection
 
         if not conn:
-            conn = self.connection_pool.get_connection()
+            conn = await self.connection_pool.get_connection()
             # assign to self.connection so reset() releases the connection
             # back to the pool after we're done
             self.connection = conn
diff --git a/coredis/pool.py b/coredis/pool.py
index 93841154..87104891 100644
--- a/coredis/pool.py
+++ b/coredis/pool.py
@@ -280,12 +280,14 @@ def _checkpid(self):
             self.disconnect()
             self.reset()
 
-    def get_connection(self, *args, **kwargs):
+    async def get_connection(self, *args, **kwargs):
         """Gets a connection from the pool"""
         self._checkpid()
         try:
             connection = self._available_connections.pop()
         except IndexError:
+            if self._created_connections >= self.max_connections:
+                raise ConnectionError("Too many connections")
             connection = self.make_connection()
         self._in_use_connections.add(connection)
 
@@ -294,8 +296,6 @@ def make_connection(self):
         """Creates a new connection"""
-        if self._created_connections >= self.max_connections:
-            raise ConnectionError("Too many connections")
         self._created_connections += 1
         connection = self.connection_class(**self.connection_kwargs)
 
@@ -329,6 +329,140 @@ def disconnect(self):
             self._created_connections -= 1
 
 
+class BlockingConnectionPool(ConnectionPool):
+    """
+    Blocking connection pool::
+
+        >>> from coredis import StrictRedis
+        >>> client = StrictRedis(connection_pool=BlockingConnectionPool())
+
+    It performs the same function as the default
+    :py:class:`~coredis.ConnectionPool` implementation, in that it
+    maintains a pool of reusable connections that can be shared by
+    multiple redis clients.
+
+    The difference is that, in the event that a client tries to get a
+    connection from the pool when all connections are in use, rather than
+    raising a :py:class:`~coredis.ConnectionError` (as the default
+    :py:class:`~coredis.ConnectionPool` implementation does), it
+    makes the client wait ("blocks") for a specified number of seconds until
+    a connection becomes available.
+
+    Use ``max_connections`` to increase / decrease the pool size::
+
+        >>> pool = BlockingConnectionPool(max_connections=10)
+
+    Use ``timeout`` to tell it either how many seconds to wait for a connection
+    to become available, or to block forever::
+
+        >>> # Block forever.
+        >>> pool = BlockingConnectionPool(timeout=None)
+
+        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
+        >>> # not available.
+        >>> pool = BlockingConnectionPool(timeout=5)
+    """
+
+    def __init__(
+        self,
+        connection_class=Connection,
+        queue_class=asyncio.LifoQueue,
+        max_connections=None,
+        timeout=20,
+        max_idle_time=0,
+        idle_check_interval=1,
+        **connection_kwargs
+    ):
+
+        self.timeout = timeout
+        self.queue_class = queue_class
+
+        max_connections = max_connections or 50
+
+        super(BlockingConnectionPool, self).__init__(
+            connection_class=connection_class,
+            max_connections=max_connections,
+            max_idle_time=max_idle_time,
+            idle_check_interval=idle_check_interval,
+            **connection_kwargs
+        )
+
+    async def disconnect_on_idle_time_exceeded(self, connection):
+        while True:
+            if (
+                time.time() - connection.last_active_at > self.max_idle_time
+                and not connection.awaiting_response
+            ):
+                # Unlike the non-blocking pool, we don't free the connection object,
+                # but always reuse it
+                connection.disconnect()
+
+                break
+            await asyncio.sleep(self.idle_check_interval)
+
+    def reset(self):
+        self._pool = self.queue_class(self.max_connections)
+
+        while True:
+            try:
+                self._pool.put_nowait(None)
+            except asyncio.QueueFull:
+                break
+
+        super(BlockingConnectionPool, self).reset()
+
+    async def get_connection(self, *args, **kwargs):
+        """Gets a connection from the pool"""
+        self._checkpid()
+
+        connection = None
+
+        try:
+            connection = await asyncio.wait_for(self._pool.get(), self.timeout)
+        except asyncio.TimeoutError:
+            raise ConnectionError("No connection available.")
+
+        if connection is None:
+            connection = self.make_connection()
+
+        self._in_use_connections.add(connection)
+
+        return connection
+
+    def release(self, connection):
+        """Releases the connection back to the pool"""
+        self._checkpid()
+
+        if connection.pid != self.pid:
+            return
+        self._in_use_connections.remove(connection)
+        # discard connection with unread response
+
+        if connection.awaiting_response:
+            connection.disconnect()
+            connection = None
+
+        try:
+            self._pool.put_nowait(connection)
+        except asyncio.QueueFull:
+            # perhaps the pool has been reset()?
+            pass
+
+    def disconnect(self):
+        """Closes all connections in the pool"""
+        pooled_connections = []
+
+        while True:
+            try:
+                pooled_connections.append(self._pool.get_nowait())
+            except asyncio.QueueEmpty:
+                break
+
+        for conn in pooled_connections:
+            try:
+                self._pool.put_nowait(conn)
+            except asyncio.QueueFull:
+                pass
+
+        all_conns = chain(pooled_connections, self._in_use_connections)
+
+        for connection in all_conns:
+            if connection is not None:
+                connection.disconnect()
+
+
 class ClusterConnectionPool(ConnectionPool):
     """Custom connection pool for rediscluster"""
 
@@ -454,7 +588,7 @@ def _checkpid(self):
             self.disconnect()
             self.reset()
 
-    def get_connection(self, command_name, *keys, **options):
+    async def get_connection(self, command_name, *keys, **options):
         # Only pubsub command/connection should be allowed here
 
         if command_name != "pubsub":
diff --git a/coredis/pubsub.py b/coredis/pubsub.py
index 12a8152d..5af503b7 100644
--- a/coredis/pubsub.py
+++ b/coredis/pubsub.py
@@ -21,15 +21,18 @@ def __init__(self, connection_pool, ignore_subscribe_messages=False):
         self.connection_pool = connection_pool
         self.ignore_subscribe_messages = ignore_subscribe_messages
         self.connection = None
-        # we need to know the encoding options for this connection in order
-        # to lookup channel and pattern names for callback handlers.
-        conn = connection_pool.get_connection("pubsub")
+        self.reset()
+
+    async def _ensure_encoding(self):
+        if hasattr(self, "encoding"):
+            return
+
+        conn = await self.connection_pool.get_connection("pubsub")
         try:
             self.encoding = conn.encoding
             self.decode_responses = conn.decode_responses
         finally:
-            connection_pool.release(conn)
-            self.reset()
+            self.connection_pool.release(conn)
 
     def __del__(self):
         try:
@@ -54,6 +57,7 @@ def close(self):
 
     async def on_connect(self, connection):
         """Re-subscribe to any channels and patterns previously subscribed to"""
+
         if self.channels:
             channels = {}
@@ -98,10 +102,10 @@ async def execute_command(self, *args, **kwargs):
         # legitimate message off the stack if the connection is already
         # subscribed to one or more channels
 
+        await self._ensure_encoding()
+
         if self.connection is None:
-            self.connection = self.connection_pool.get_connection()
-            # register a callback that re-subscribes to any channels we
-            # were listening to when we were disconnected
+            self.connection = await self.connection_pool.get_connection()
             self.connection.register_connect_callback(self.on_connect)
         connection = self.connection
         await self._execute(connection, connection.send_command, *args)
@@ -159,6 +163,7 @@ async def psubscribe(self, *args, **kwargs):
         received on that pattern rather than producing a message via
         ``listen()``.
         """
+        await self._ensure_encoding()
         if args:
             args = list_or_args(args[0], args[1:])
@@ -180,6 +185,7 @@ async def punsubscribe(self, *args):
         Unsubscribes from the supplied patterns. If empy, unsubscribe
         from all patterns.
         """
+        await self._ensure_encoding()
         if args:
             args = list_or_args(args[0], args[1:])
@@ -195,6 +201,8 @@ async def subscribe(self, *args, **kwargs):
         ``get_message()``.
""" + await self._ensure_encoding() + if args: args = list_or_args(args[0], args[1:]) new_channels = {} @@ -216,6 +224,8 @@ async def unsubscribe(self, *args): all channels """ + await self._ensure_encoding() + if args: args = list_or_args(args[0], args[1:]) @@ -385,7 +395,7 @@ async def execute_command(self, *args, **kwargs): await self.connection_pool.initialize() if self.connection is None: - self.connection = self.connection_pool.get_connection( + self.connection = await self.connection_pool.get_connection( "pubsub", channel=args[1], ) diff --git a/docs/source/api.rst b/docs/source/api.rst index ac869117..62e214a5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -21,6 +21,10 @@ Connection Classes .. autoclass:: Connection .. autoclass:: UnixDomainSocketConnection .. autoclass:: ClusterConnection + +Connection Pools +^^^^^^^^^^^^^^^^ +.. autoclass:: BlockingConnectionPool .. autoclass:: ConnectionPool .. autoclass:: ClusterConnectionPool diff --git a/tests/client/test_connection_pool.py b/tests/client/test_connection_pool.py index 40686d15..f40d6ecd 100644 --- a/tests/client/test_connection_pool.py +++ b/tests/client/test_connection_pool.py @@ -38,12 +38,14 @@ def get_pool( max_connections=max_connections, **connection_kwargs ) + return pool - def test_connection_creation(self): + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_connection_creation(self): connection_kwargs = {"foo": "bar", "biz": "baz"} pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = pool.get_connection() + connection = await pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -53,18 +55,20 @@ def test_multiple_connections(self): c2 = pool.get_connection() assert c1 != c2 - def test_max_connections(self): + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_max_connections(self): pool = self.get_pool(max_connections=2) - pool.get_connection() - pool.get_connection() + await pool.get_connection() + await pool.get_connection() with pytest.raises(ConnectionError): - pool.get_connection() + await pool.get_connection() - def test_reuse_previously_released_connection(self): + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_reuse_previously_released_connection(self): pool = self.get_pool() - c1 = pool.get_connection() + c1 = await pool.get_connection() pool.release(c1) - c2 = pool.get_connection() + c2 = await pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -105,6 +109,108 @@ async def test_connection_idle_check(self, event_loop): assert conn._writer is None and conn._reader is None +class TestBlockingConnectionPool: + def get_pool( + self, + connection_kwargs=None, + max_connections=None, + connection_class=DummyConnection, + timeout=None, + ): + connection_kwargs = connection_kwargs or {} + pool = coredis.BlockingConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs + ) + + return pool + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_connection_creation(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = await pool.get_connection() + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_multiple_connections(self): + pool = self.get_pool() + c1 = await 
pool.get_connection() + c2 = await pool.get_connection() + assert c1 != c2 + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_max_connections_timeout(self): + pool = self.get_pool(max_connections=2, timeout=0.1) + await pool.get_connection() + await pool.get_connection() + with pytest.raises(ConnectionError): + await pool.get_connection() + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_max_connections_no_timeout(self): + pool = self.get_pool(max_connections=2) + await pool.get_connection() + released_conn = await pool.get_connection() + + def releaser(): + pool.release(released_conn) + + loop = asyncio.get_running_loop() + loop.call_later(0.2, releaser) + new_conn = await pool.get_connection() + assert new_conn == released_conn + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_reuse_previously_released_connection(self): + pool = self.get_pool() + c1 = await pool.get_connection() + pool.release(c1) + c2 = await pool.get_connection() + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = {"host": "localhost", "port": 6379, "db": 1} + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=coredis.Connection + ) + expected = "BlockingConnectionPool>" + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {"path": "/abc", "db": 1} + pool = self.get_pool( + connection_kwargs=connection_kwargs, + connection_class=coredis.UnixDomainSocketConnection, + ) + expected = "BlockingConnectionPool>" + assert repr(pool) == expected + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_connection_idle_check(self, event_loop): + rs = coredis.StrictRedis( + host="127.0.0.1", + port=6379, + db=0, + max_idle_time=0.2, + idle_check_interval=0.1, + ) + await rs.info() + assert len(rs.connection_pool._in_use_connections) == 0 + conn = await rs.connection_pool.get_connection() + last_active_at = conn.last_active_at + rs.connection_pool.release(conn) + await asyncio.sleep(0.3) + assert len(rs.connection_pool._in_use_connections) == 0 + assert last_active_at == conn.last_active_at + assert conn._writer is None and conn._reader is None + new_conn = rs.connection_pool.get_connection() + assert conn != new_conn + + class TestConnectionPoolURLParsing: def test_defaults(self): pool = coredis.ConnectionPool.from_url("redis://localhost")
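
For reference, a minimal usage sketch of the BlockingConnectionPool introduced above. It is illustrative only: the host, port, db, max_connections and timeout values are example assumptions, not part of the diff, and it presumes a Redis server reachable at 127.0.0.1:6379. When every pooled connection is checked out, get_connection() waits up to ``timeout`` seconds for one to be released instead of raising ConnectionError immediately.

import asyncio

import coredis


async def main():
    # Example values (assumed): at most 10 connections, wait up to
    # 5 seconds for a free connection before raising ConnectionError.
    pool = coredis.BlockingConnectionPool(
        host="127.0.0.1", port=6379, db=0, max_connections=10, timeout=5
    )
    client = coredis.StrictRedis(connection_pool=pool)

    # Each command checks a connection out of the pool and releases it
    # when the command completes, so concurrent tasks share the pool.
    await client.set("foo", "bar")
    print(await client.get("foo"))


asyncio.run(main())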