diff --git a/coredis/client.py b/coredis/client.py index e5d5990f..e8b44de5 100644 --- a/coredis/client.py +++ b/coredis/client.py @@ -157,6 +157,7 @@ def __init__( retry_on_timeout=False, max_idle_time=0, idle_check_interval=1, + client_name=None, loop=None, **kwargs ): @@ -172,6 +173,7 @@ def __init__( "decode_responses": decode_responses, "max_idle_time": max_idle_time, "idle_check_interval": idle_check_interval, + "client_name": client_name, "loop": loop, } # based on input, setup appropriate connection args diff --git a/coredis/connection.py b/coredis/connection.py index ecf9702c..157cc908 100755 --- a/coredis/connection.py +++ b/coredis/connection.py @@ -376,7 +376,8 @@ def __init__( encoding="utf-8", decode_responses=False, *, - loop=None + client_name=None, + loop=None, ): self._parser = parser_class(reader_read_size) self._stream_timeout = stream_timeout @@ -391,6 +392,7 @@ def __init__( self.encoding = encoding self.decode_responses = decode_responses self.loop = loop + self.client_name = client_name # flag to show if a connection is waiting for response self.awaiting_response = False self.last_active_at = time.time() @@ -454,6 +456,11 @@ async def on_connect(self): await self.send_command("SELECT", self.db) if nativestr(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") + + if self.client_name is not None: + await self.send_command('CLIENT SETNAME', self.client_name) + if nativestr(await self.read_response()) != "OK": + raise ConnectionError(f"Failed to set client name: {self.client_name}") self.last_active_at = time.time() async def read_response(self): @@ -598,6 +605,7 @@ def __init__( socket_keepalive=None, socket_keepalive_options=None, *, + client_name=None, loop=None ): super(Connection, self).__init__( @@ -607,6 +615,7 @@ def __init__( reader_read_size, encoding, decode_responses, + client_name=client_name, loop=loop, ) self.host = host @@ -667,6 +676,7 @@ def __init__( encoding="utf-8", decode_responses=False, *, + client_name=None, loop=None ): super(UnixDomainSocketConnection, self).__init__( @@ -676,6 +686,7 @@ def __init__( reader_read_size, encoding, decode_responses, + client_name=client_name, loop=loop, ) self.path = path