From eb7c551923efded0d0a28ce15903785fb6f11031 Mon Sep 17 00:00:00 2001 From: abersheeran Date: Wed, 23 Jun 2021 17:04:07 +0800 Subject: [PATCH] more tests --- ratelimit/backends/slidingredis.py | 11 ++++++----- ratelimit/rule.py | 2 +- tests/backends/test_redis.py | 2 ++ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ratelimit/backends/slidingredis.py b/ratelimit/backends/slidingredis.py index 2752631..a043fdd 100644 --- a/ratelimit/backends/slidingredis.py +++ b/ratelimit/backends/slidingredis.py @@ -53,7 +53,7 @@ def __init__( ) self.sliding_function = self._redis.register_script(SLIDING_WINDOW_SCRIPT) - async def get_limits(self, path: str, user: str, rule: Rule) -> bool: + async def get_limits(self, path: str, user: str, rule: Rule) -> dict: epoch = time.time() ruleset = rule.ruleset(path, user) r = await self.sliding_function.execute( @@ -69,10 +69,10 @@ async def get_limits(self, path: str, user: str, rule: Rule) -> bool: async def set_block_time(self, user: str, block_time: int) -> None: await self._redis.set(f"blocking:{user}", True, block_time) - async def is_blocking(self, user: str) -> bool: - return bool(await self._redis.get(f"blocking:{user}")) + async def is_blocking(self, user: str) -> int: + return int(await self._redis.ttl(f"blocking:{user}")) - async def retry_after(self, path: str, user: str, rule: Rule) -> bool: + async def retry_after(self, path: str, user: str, rule: Rule) -> int: block_time = await self.is_blocking(user) if block_time > 0: return block_time @@ -82,5 +82,6 @@ async def retry_after(self, path: str, user: str, rule: Rule) -> bool: if retry_after > 0 and rule.block_time: await self.set_block_time(user, rule.block_time) + retry_after = rule.block_time - return retry_after + return round(retry_after) diff --git a/ratelimit/rule.py b/ratelimit/rule.py index 0bd7cf5..dceff60 100644 --- a/ratelimit/rule.py +++ b/ratelimit/rule.py @@ -34,4 +34,4 @@ def ruleset(self, path: str, user: str) -> Dict[str, Tuple[int, int]]: "month": 31 * 24 * 60 * 60, } -RULENAMES: Tuple[str] = ("second", "minute", "hour", "day", "month") +RULENAMES: Tuple[str, ...] = ("second", "minute", "hour", "day", "month") diff --git a/tests/backends/test_redis.py b/tests/backends/test_redis.py index cf90ca5..91044ad 100644 --- a/tests/backends/test_redis.py +++ b/tests/backends/test_redis.py @@ -95,6 +95,7 @@ async def test_redis(redisbackend): "/second_limit", headers={"user": "user", "group": "default"} ) assert response.status_code == 429 + assert response.headers["retry-after"] == "1" response = await client.get( "/second_limit", headers={"user": "admin-user", "group": "admin"} @@ -132,6 +133,7 @@ async def test_redis(redisbackend): "/block", headers={"user": "user", "group": "default"} ) assert response.status_code == 429 + assert response.headers["retry-after"] == "5" await asyncio.sleep(1)