From 51763a18bc2858ce85db4dd6c8cba95ab1ff831a Mon Sep 17 00:00:00 2001 From: Dov Reshef Date: Tue, 4 Aug 2020 09:52:56 +0300 Subject: [PATCH] Add the option to set the client name --- .gitignore | 1 + aredis/client.py | 2 ++ aredis/connection.py | 16 +++++++++++----- tests/client/conftest.py | 2 +- tests/client/test_commands.py | 2 +- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 344dde3f..a6b36ced 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ vagrant/.vagrant .vscode/ *.iml .pytest_cache/ +*.so diff --git a/aredis/client.py b/aredis/client.py index 91973db5..ed92c1a5 100644 --- a/aredis/client.py +++ b/aredis/client.py @@ -100,6 +100,7 @@ def __init__(self, host='localhost', port=6379, ssl_cert_reqs=None, ssl_ca_certs=None, max_connections=None, retry_on_timeout=False, max_idle_time=0, idle_check_interval=1, + client_name=None, loop=None, **kwargs): if not connection_pool: kwargs = { @@ -113,6 +114,7 @@ def __init__(self, host='localhost', port=6379, 'decode_responses': decode_responses, 'max_idle_time': max_idle_time, 'idle_check_interval': idle_check_interval, + 'client_name': client_name, 'loop': loop } # based on input, setup appropriate connection args diff --git a/aredis/connection.py b/aredis/connection.py index 6dc79050..4c874754 100755 --- a/aredis/connection.py +++ b/aredis/connection.py @@ -367,7 +367,7 @@ class BaseConnection: def __init__(self, retry_on_timeout=False, stream_timeout=None, parser_class=DefaultParser, reader_read_size=65535, encoding='utf-8', decode_responses=False, - *, loop=None): + *, client_name=None, loop=None): self._parser = parser_class(reader_read_size) self._stream_timeout = stream_timeout self._reader = None @@ -381,6 +381,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None, self.encoding = encoding self.decode_responses = decode_responses self.loop = loop + self.client_name = client_name # flag to show if a connection is waiting for response self.awaiting_response = False self.last_active_at = time.time() @@ -444,6 +445,11 @@ async def on_connect(self): await self.send_command('SELECT', self.db) if nativestr(await self.read_response()) != 'OK': raise ConnectionError('Invalid Database') + + if self.client_name is not None: + await self.send_command('CLIENT SETNAME', self.client_name) + if nativestr(await self.read_response()) != 'OK': + raise ConnectionError('Failed to set client name: {}'.format(self.client_name)) self.last_active_at = time.time() async def read_response(self): @@ -573,11 +579,11 @@ def __init__(self, host='127.0.0.1', port=6379, password=None, db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None, ssl_context=None, parser_class=DefaultParser, reader_read_size=65535, encoding='utf-8', decode_responses=False, socket_keepalive=None, - socket_keepalive_options=None, *, loop=None): + socket_keepalive_options=None, *, loop=None, client_name=None): super(Connection, self).__init__(retry_on_timeout, stream_timeout, parser_class, reader_read_size, encoding, decode_responses, - loop=loop) + client_name=client_name, loop=loop) self.host = host self.port = port self.password = password @@ -626,11 +632,11 @@ class UnixDomainSocketConnection(BaseConnection): def __init__(self, path='', password=None, db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None, ssl_context=None, parser_class=DefaultParser, reader_read_size=65535, - encoding='utf-8', decode_responses=False, *, loop=None): + encoding='utf-8', decode_responses=False, *, loop=None, client_name=None): super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout, parser_class, reader_read_size, encoding, decode_responses, - loop=loop) + client_name=client_name, loop=loop) self.path = path self.db = db self.password = password diff --git a/tests/client/conftest.py b/tests/client/conftest.py index afc4d3ae..e58aff6a 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -37,7 +37,7 @@ def skip_python_vsersion_lt(min_version): @pytest.fixture() def r(event_loop): - return aredis.StrictRedis(loop=event_loop) + return aredis.StrictRedis(client_name='test', loop=event_loop) class AsyncMock(Mock): diff --git a/tests/client/test_commands.py b/tests/client/test_commands.py index 6c47c53d..0097461e 100644 --- a/tests/client/test_commands.py +++ b/tests/client/test_commands.py @@ -64,7 +64,7 @@ async def test_client_list_after_client_setname(self, r): @skip_if_server_version_lt('2.6.9') @pytest.mark.asyncio(forbid_global_loop=True) async def test_client_getname(self, r): - assert await r.client_getname() is None + assert await r.client_getname() == 'test' @skip_if_server_version_lt('2.6.9') @pytest.mark.asyncio(forbid_global_loop=True)