Skip to content

Commit

Permalink
Add the option to set the client name
Browse files Browse the repository at this point in the history
  • Loading branch information
dovreshef committed Aug 4, 2020
1 parent cec3cda commit cf1221a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ vagrant/.vagrant
.vscode/
*.iml
.pytest_cache/
*.so
2 changes: 2 additions & 0 deletions aredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, host='localhost', port=6379,
ssl_cert_reqs=None, ssl_ca_certs=None,
max_connections=None, retry_on_timeout=False,
max_idle_time=0, idle_check_interval=1,
client_name=None,
loop=None, **kwargs):
if not connection_pool:
kwargs = {
Expand All @@ -113,6 +114,7 @@ def __init__(self, host='localhost', port=6379,
'decode_responses': decode_responses,
'max_idle_time': max_idle_time,
'idle_check_interval': idle_check_interval,
'client_name': client_name,
'loop': loop
}
# based on input, setup appropriate connection args
Expand Down
16 changes: 11 additions & 5 deletions aredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class BaseConnection:
def __init__(self, retry_on_timeout=False, stream_timeout=None,
parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False,
*, loop=None):
*, loop=None, client_name=None):
self._parser = parser_class(reader_read_size)
self._stream_timeout = stream_timeout
self._reader = None
Expand All @@ -381,6 +381,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None,
self.encoding = encoding
self.decode_responses = decode_responses
self.loop = loop
self.client_name = client_name
# flag to show if a connection is waiting for response
self.awaiting_response = False
self.last_active_at = time.time()
Expand Down Expand Up @@ -442,6 +443,11 @@ async def on_connect(self):
await self.send_command('SELECT', self.db)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Invalid Database')

if self.client_name:
await self.send_command('CLIENT SETNAME', self.client_name)
if nativestr(await self.read_response()) != 'OK':
raise ConnectionError('Failed to set client name: {}'.format(self.client_name))
self.last_active_at = time.time()

async def read_response(self):
Expand Down Expand Up @@ -571,11 +577,11 @@ def __init__(self, host='127.0.0.1', port=6379, password=None,
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False, socket_keepalive=None,
socket_keepalive_options=None, *, loop=None):
socket_keepalive_options=None, *, loop=None, client_name=None):
super(Connection, self).__init__(retry_on_timeout, stream_timeout,
parser_class, reader_read_size,
encoding, decode_responses,
loop=loop)
loop=loop, client_name=client_name)
self.host = host
self.port = port
self.password = password
Expand Down Expand Up @@ -624,11 +630,11 @@ class UnixDomainSocketConnection(BaseConnection):
def __init__(self, path='', password=None,
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
encoding='utf-8', decode_responses=False, *, loop=None):
encoding='utf-8', decode_responses=False, *, loop=None, client_name=None):
super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout,
parser_class, reader_read_size,
encoding, decode_responses,
loop=loop)
loop=loop, client_name=client_name)
self.path = path
self.db = db
self.password = password
Expand Down

0 comments on commit cf1221a

Please sign in to comment.