Skip to content

Commit

Permalink
Fix review comments #597
Browse files Browse the repository at this point in the history
  • Loading branch information
rootart committed Jun 15, 2024
1 parent 0bb4f32 commit f34935c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
31 changes: 17 additions & 14 deletions django_redis/client/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,17 @@ def encode(self, value: EncodableT) -> Union[bytes, int]:

return value

def _decode_iterable_result(
self, result: Any, covert_to_set: bool = True
) -> Union[List[Any], None, Any]:
if result is None:
return None
if isinstance(result, list):
if covert_to_set:
return {self.decode(value) for value in result}
return [self.decode(value) for value in result]
return self.decode(result)

def get_many(
self,
keys: Iterable[KeyT],
Expand Down Expand Up @@ -828,7 +839,7 @@ def sdiff(
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set:
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

Expand All @@ -855,7 +866,7 @@ def sinter(
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set:
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

Expand Down Expand Up @@ -910,7 +921,7 @@ def smembers(
key: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set:
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

Expand Down Expand Up @@ -945,11 +956,7 @@ def spop(

nkey = self.make_key(key, version=version)
result = client.spop(nkey, count)
if result is None:
return None
if isinstance(result, list):
return {self.decode(value) for value in result}
return self.decode(result)
return self._decode_iterable_result(result)

def srandmember(
self,
Expand All @@ -963,11 +970,7 @@ def srandmember(

key = self.make_key(key, version=version)
result = client.srandmember(key, count)
if result is None:
return None
if isinstance(result, list):
return [self.decode(value) for value in result]
return self.decode(result)
return self._decode_iterable_result(result, covert_to_set=False)

def srem(
self,
Expand Down Expand Up @@ -1035,7 +1038,7 @@ def sunion(
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set:
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

Expand Down
2 changes: 1 addition & 1 deletion django_redis/client/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def smembers(
key: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set:
) -> Set[Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
Expand Down

0 comments on commit f34935c

Please sign in to comment.