Skip to content

Commit

Permalink
Add option to add client_name
Browse files Browse the repository at this point in the history
Derived from NoneGG/aredis#157
  • Loading branch information
alisaifee committed Jan 16, 2022
1 parent 27e015d commit b168e3d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 2 additions & 0 deletions coredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
retry_on_timeout=False,
max_idle_time=0,
idle_check_interval=1,
client_name=None,
loop=None,
**kwargs
):
Expand All @@ -172,6 +173,7 @@ def __init__(
"decode_responses": decode_responses,
"max_idle_time": max_idle_time,
"idle_check_interval": idle_check_interval,
"client_name": client_name,
"loop": loop,
}
# based on input, setup appropriate connection args
Expand Down
13 changes: 12 additions & 1 deletion coredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def __init__(
encoding="utf-8",
decode_responses=False,
*,
loop=None
client_name=None,
loop=None,
):
self._parser = parser_class(reader_read_size)
self._stream_timeout = stream_timeout
Expand All @@ -391,6 +392,7 @@ def __init__(
self.encoding = encoding
self.decode_responses = decode_responses
self.loop = loop
self.client_name = client_name
# flag to show if a connection is waiting for response
self.awaiting_response = False
self.last_active_at = time.time()
Expand Down Expand Up @@ -454,6 +456,11 @@ async def on_connect(self):
await self.send_command("SELECT", self.db)
if nativestr(await self.read_response()) != "OK":
raise ConnectionError("Invalid Database")

if self.client_name is not None:
await self.send_command('CLIENT SETNAME', self.client_name)
if nativestr(await self.read_response()) != "OK":
raise ConnectionError(f"Failed to set client name: {self.client_name}")
self.last_active_at = time.time()

async def read_response(self):
Expand Down Expand Up @@ -598,6 +605,7 @@ def __init__(
socket_keepalive=None,
socket_keepalive_options=None,
*,
client_name=None,
loop=None
):
super(Connection, self).__init__(
Expand All @@ -607,6 +615,7 @@ def __init__(
reader_read_size,
encoding,
decode_responses,
client_name=client_name,
loop=loop,
)
self.host = host
Expand Down Expand Up @@ -667,6 +676,7 @@ def __init__(
encoding="utf-8",
decode_responses=False,
*,
client_name=None,
loop=None
):
super(UnixDomainSocketConnection, self).__init__(
Expand All @@ -676,6 +686,7 @@ def __init__(
reader_read_size,
encoding,
decode_responses,
client_name=client_name,
loop=loop,
)
self.path = path
Expand Down

0 comments on commit b168e3d

Please sign in to comment.