From f34935cba01c7577b1200d2a6e23ea47efb9a524 Mon Sep 17 00:00:00 2001 From: Vasyl Dizhak Date: Sat, 15 Jun 2024 14:46:08 +0200 Subject: [PATCH] Fix review comments #597 --- django_redis/client/default.py | 31 +++++++++++++++++-------------- django_redis/client/sharded.py | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 6f51a4cd..3219f7c9 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -517,6 +517,17 @@ def encode(self, value: EncodableT) -> Union[bytes, int]: return value + def _decode_iterable_result( + self, result: Any, covert_to_set: bool = True + ) -> Union[List[Any], None, Any]: + if result is None: + return None + if isinstance(result, list): + if covert_to_set: + return {self.decode(value) for value in result} + return [self.decode(value) for value in result] + return self.decode(result) + def get_many( self, keys: Iterable[KeyT], @@ -828,7 +839,7 @@ def sdiff( *keys: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, - ) -> Set: + ) -> Set[Any]: if client is None: client = self.get_client(write=False) @@ -855,7 +866,7 @@ def sinter( *keys: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, - ) -> Set: + ) -> Set[Any]: if client is None: client = self.get_client(write=False) @@ -910,7 +921,7 @@ def smembers( key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, - ) -> Set: + ) -> Set[Any]: if client is None: client = self.get_client(write=False) @@ -945,11 +956,7 @@ def spop( nkey = self.make_key(key, version=version) result = client.spop(nkey, count) - if result is None: - return None - if isinstance(result, list): - return {self.decode(value) for value in result} - return self.decode(result) + return self._decode_iterable_result(result) def srandmember( self, @@ -963,11 +970,7 @@ def srandmember( key = self.make_key(key, version=version) result = client.srandmember(key, count) - if result is None: - return None - if isinstance(result, list): - return [self.decode(value) for value in result] - return self.decode(result) + return self._decode_iterable_result(result, covert_to_set=False) def srem( self, @@ -1035,7 +1038,7 @@ def sunion( *keys: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, - ) -> Set: + ) -> Set[Any]: if client is None: client = self.get_client(write=False) diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index 871b1df1..5e2eec90 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -367,7 +367,7 @@ def smembers( key: KeyT, version: Optional[int] = None, client: Optional[Redis] = None, - ) -> Set: + ) -> Set[Any]: if client is None: key = self.make_key(key, version=version) client = self.get_server(key)