Skip to content

Commit

Permalink
Add the option to set the client name
Browse files Browse the repository at this point in the history
  • Loading branch information
dovreshef committed Nov 8, 2020
1 parent b46e671 commit 67c92a0
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ vagrant/.vagrant
.vscode/
*.iml
.pytest_cache/
*.so
2 changes: 2 additions & 0 deletions aredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, host='localhost', port=6379,
ssl_cert_reqs=None, ssl_ca_certs=None,
max_connections=None, retry_on_timeout=False,
max_idle_time=0, idle_check_interval=1,
client_name=None,
loop=None, **kwargs):
if not connection_pool:
kwargs = {
Expand All @@ -113,6 +114,7 @@ def __init__(self, host='localhost', port=6379,
'decode_responses': decode_responses,
'max_idle_time': max_idle_time,
'idle_check_interval': idle_check_interval,
'client_name': client_name,
'loop': loop
}
# based on input, setup appropriate connection args
Expand Down
16 changes: 11 additions & 5 deletions aredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class BaseConnection:
def __init__(self, retry_on_timeout=False, stream_timeout=None,
parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False,
*, loop=None):
*, client_name=None, loop=None):
self._parser = parser_class(reader_read_size)
self._stream_timeout = stream_timeout
self._reader = None
Expand All @@ -381,6 +381,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None,
self.encoding = encoding
self.decode_responses = decode_responses
self.loop = loop
self.client_name = client_name
# flag to show if a connection is waiting for response
self.awaiting_response = False
self.last_active_at = time.time()
Expand Down Expand Up @@ -444,6 +445,11 @@ async def on_connect(self):
await self.send_command('SELECT', self.db)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Invalid Database')

if self.client_name is not None:
await self.send_command('CLIENT SETNAME', self.client_name)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Failed to set client name: {}'.format(self.client_name))
self.last_active_at = time.time()

async def read_response(self):
Expand Down Expand Up @@ -573,11 +579,11 @@ def __init__(self, host='127.0.0.1', port=6379, password=None,
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False, socket_keepalive=None,
socket_keepalive_options=None, *, loop=None):
socket_keepalive_options=None, *, client_name=None, loop=None):
super(Connection, self).__init__(retry_on_timeout, stream_timeout,
parser_class, reader_read_size,
encoding, decode_responses,
loop=loop)
client_name=client_name, loop=loop)
self.host = host
self.port = port
self.password = password
Expand Down Expand Up @@ -626,11 +632,11 @@ class UnixDomainSocketConnection(BaseConnection):
def __init__(self, path='', password=None,
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False, *, loop=None):
encoding='utf-8', decode_responses=False, *, client_name=None, loop=None):
super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout,
parser_class, reader_read_size,
encoding, decode_responses,
loop=loop)
client_name=client_name, loop=loop)
self.path = path
self.db = db
self.password = password
Expand Down
2 changes: 1 addition & 1 deletion tests/client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def skip_python_vsersion_lt(min_version):

@pytest.fixture()
def r(event_loop):
return aredis.StrictRedis(loop=event_loop)
return aredis.StrictRedis(client_name='test', loop=event_loop)


class AsyncMock(Mock):
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def test_client_list_after_client_setname(self, r):
@skip_if_server_version_lt('2.6.9')
@pytest.mark.asyncio(forbid_global_loop=True)
async def test_client_getname(self, r):
assert await r.client_getname() is None
assert await r.client_getname() == 'test'

@skip_if_server_version_lt('2.6.9')
@pytest.mark.asyncio(forbid_global_loop=True)
Expand Down

0 comments on commit 67c92a0

Please sign in to comment.