From 4923604f36ba79655322d6d96b0fb4d5db6a575d Mon Sep 17 00:00:00 2001 From: KShivendu Date: Thu, 24 Aug 2023 16:43:31 +0530 Subject: [PATCH 01/12] feat: Add 429 rate limting from SaaS --- supertokens_python/constants.py | 1 + supertokens_python/querier.py | 32 ++++++- tests/test_querier.py | 150 ++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 3 deletions(-) create mode 100644 tests/test_querier.py diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 070cfb626..b192cfda4 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -29,3 +29,4 @@ API_VERSION_HEADER = "cdi-version" DASHBOARD_VERSION = "0.6" HUNDRED_YEARS_IN_MS = 3153600000000 +RATE_LIMIT_STATUS_CODE = 429 diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index e3da29362..79f6ae3dc 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -13,9 +13,11 @@ # under the License. from __future__ import annotations +import asyncio + from json import JSONDecodeError from os import environ -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from httpx import AsyncClient, ConnectTimeout, NetworkError, Response @@ -25,6 +27,7 @@ API_VERSION_HEADER, RID_KEY_HEADER, SUPPORTED_CDI_VERSIONS, + RATE_LIMIT_STATUS_CODE, ) from .normalised_url_path import NormalisedURLPath @@ -222,6 +225,7 @@ async def __send_request_helper( method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -238,6 +242,14 @@ async def __send_request_helper( Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -247,6 +259,20 @@ async def __send_request_helper( ): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -264,9 +290,9 @@ async def __send_request_helper( except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) diff --git a/tests/test_querier.py b/tests/test_querier.py new file mode 100644 index 000000000..86d12aaf0 --- /dev/null +++ b/tests/test_querier.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from pytest import mark +from supertokens_python.recipe import ( + session, + emailpassword, + emailverification, + dashboard, +) +import asyncio +import respx +import httpx +from supertokens_python import init, SupertokensConfig +from supertokens_python.querier import Querier, NormalisedURLPath + +from tests.utils import get_st_init_args +from tests.utils import ( + setup_function, + teardown_function, + start_st, +) + +_ = setup_function +_ = teardown_function + +pytestmark = mark.asyncio +respx_mock = respx.MockRouter + + +async def test_network_call_is_retried_as_expected(): + # Test that network call is retried properly + # Test that rate limiting errors are thrown back to the user + args = get_st_init_args( + [ + session.init(), + emailpassword.init(), + emailverification.init(mode="OPTIONAL"), + dashboard.init(), + ] + ) + args["supertokens_config"] = SupertokensConfig("http://localhost:6789") + init(**args) # type: ignore + start_st() + + Querier.api_version = "3.0" + q = Querier.get_instance() + + api2_call_count = 0 + + def api2_side_effect(_: httpx.Request): + nonlocal api2_call_count + api2_call_count += 1 + + if api2_call_count == 3: + return httpx.Response(200) + + return httpx.Response(429, json={}) + + with respx_mock() as mocker: + api1 = mocker.get("http://localhost:6789/api1").mock( + httpx.Response(429, json={"status": "RATE_ERROR"}) + ) + api2 = mocker.get("http://localhost:6789/api2").mock( + side_effect=api2_side_effect + ) + api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200)) + + try: + await q.send_get_request(NormalisedURLPath("/api1"), {}) + except Exception as e: + if "with status code: 429" in str( + e + ) and 'message: {"status": "RATE_ERROR"}' in str(e): + pass + else: + raise e + + await q.send_get_request(NormalisedURLPath("/api2"), {}) + await q.send_get_request(NormalisedURLPath("/api3"), {}) + + # 1 initial request + 5 retries + assert api1.call_count == 6 + # 2 403 and 1 200 + assert api2.call_count == 3 + # 200 in the first attempt + assert api3.call_count == 1 + + +async def test_parallel_calls_have_independent_counters(): + args = get_st_init_args( + [ + session.init(), + emailpassword.init(), + emailverification.init(mode="OPTIONAL"), + dashboard.init(), + ] + ) + init(**args) # type: ignore + start_st() + + Querier.api_version = "3.0" + q = Querier.get_instance() + + call_count1 = 0 + call_count2 = 0 + + def api_side_effect(r: httpx.Request): + nonlocal call_count1, call_count2 + + id_ = int(r.url.params.get("id")) + if id_ == 1: + call_count1 += 1 + elif id_ == 2: + call_count2 += 1 + + return httpx.Response(429, json={}) + + with respx_mock() as mocker: + api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect) + + async def call_api(id_: int): + try: + await q.send_get_request(NormalisedURLPath("/api"), {"id": id_}) + except Exception as e: + if "with status code: 429" in str(e): + pass + else: + raise e + + _ = await asyncio.gather( + call_api(1), + call_api(2), + ) + + # 1 initial request + 5 retries + assert call_count1 == 6 + assert call_count2 == 6 + + assert api.call_count == 12 From 4d67c0448f04686ad990b7a925e7578fe16875e5 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Mon, 28 Aug 2023 17:25:54 +0530 Subject: [PATCH 02/12] feat: Add retry logic for 429 from SaaS instances --- CHANGELOG.md | 6 +++++- setup.py | 2 +- supertokens_python/constants.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d4e53a53..4bc719ed1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.14.9] - 2023-09-28 + +- Add logic to retry network calls if the core returns status 429 + ## [0.14.8] - 2023-07-07 ## Fixes @@ -148,7 +152,7 @@ if (accessTokenPayload.sub !== undefined) { ```python from supertokens_python.recipe.session.interfaces import SessionContainer -session: SessionContainer = ... +session: SessionContainer = ... access_token_payload = await session.get_access_token_payload() if access_token_payload.get('sub') is not None: diff --git a/setup.py b/setup.py index 7e83f47c4..0fff1a728 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.14.8", + version="0.14.9", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index b192cfda4..32e942570 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.8" +VERSION = "0.14.9" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" From 025f6eb9b0c576513ddcedd1a9404622a465b9d9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 29 Aug 2023 11:08:51 +0530 Subject: [PATCH 03/12] mods to add dev tag --- addDevTag | 6 ------ 1 file changed, 6 deletions(-) diff --git a/addDevTag b/addDevTag index 1fd4f670e..cdde45c47 100755 --- a/addDevTag +++ b/addDevTag @@ -1,11 +1,5 @@ #!/bin/bash -# check if we need to merge master into this branch------------ -if [[ $(git log origin/master ^HEAD) ]]; then - echo "You need to merge master into this branch. Exiting" - exit 1 -fi - # get version------------ version=`cat setup.py | grep -e 'version='` while IFS='"' read -ra ADDR; do From 33a5d6baa351992c4a2a9a7fe40ce37e2c2fb069 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 29 Aug 2023 11:09:49 +0530 Subject: [PATCH 04/12] adding dev-v0.14.9 tag to this commit to ensure building --- html/supertokens_python/constants.html | 5 ++- html/supertokens_python/querier.html | 59 +++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/html/supertokens_python/constants.html b/html/supertokens_python/constants.html index c022003ff..dc1308bec 100644 --- a/html/supertokens_python/constants.html +++ b/html/supertokens_python/constants.html @@ -42,7 +42,7 @@

Module supertokens_python.constants

from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.8" +VERSION = "0.14.9" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" @@ -56,7 +56,8 @@

Module supertokens_python.constants

API_VERSION = "/apiversion" API_VERSION_HEADER = "cdi-version" DASHBOARD_VERSION = "0.6" -HUNDRED_YEARS_IN_MS = 3153600000000 +HUNDRED_YEARS_IN_MS = 3153600000000 +RATE_LIMIT_STATUS_CODE = 429
diff --git a/html/supertokens_python/querier.html b/html/supertokens_python/querier.html index 8636020dc..f95a55798 100644 --- a/html/supertokens_python/querier.html +++ b/html/supertokens_python/querier.html @@ -41,9 +41,11 @@

Module supertokens_python.querier

# under the License. from __future__ import annotations +import asyncio + from json import JSONDecodeError from os import environ -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from httpx import AsyncClient, ConnectTimeout, NetworkError, Response @@ -53,6 +55,7 @@

Module supertokens_python.querier

API_VERSION_HEADER, RID_KEY_HEADER, SUPPORTED_CDI_VERSIONS, + RATE_LIMIT_STATUS_CODE, ) from .normalised_url_path import NormalisedURLPath @@ -250,6 +253,7 @@

Module supertokens_python.querier

method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -266,6 +270,14 @@

Module supertokens_python.querier

Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -275,6 +287,20 @@

Module supertokens_python.querier

): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -292,9 +318,9 @@

Module supertokens_python.querier

except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) @@ -503,6 +529,7 @@

Classes

method: str, http_function: Callable[[str], Awaitable[Response]], no_of_tries: int, + retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: if no_of_tries == 0: raise_general_exception("No SuperTokens core available to query") @@ -519,6 +546,14 @@

Classes

Querier.__last_tried_index %= len(self.__hosts) url = current_host + path.get_as_string_dangerous() + max_retries = 5 + + if retry_info_map is None: + retry_info_map = {} + + if retry_info_map.get(url) is None: + retry_info_map[url] = max_retries + ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) @@ -528,6 +563,20 @@

Classes

): Querier.__hosts_alive_for_testing.add(current_host) + if response.status_code == RATE_LIMIT_STATUS_CODE: + retries_left = retry_info_map[url] + + if retries_left > 0: + retry_info_map[url] = retries_left - 1 + + attempts_made = max_retries - retries_left + delay = (10 + attempts_made * 250) / 1000 + + await asyncio.sleep(delay) + return await self.__send_request_helper( + path, method, http_function, no_of_tries, retry_info_map + ) + if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore raise_general_exception( "SuperTokens core threw an error for a " @@ -545,9 +594,9 @@

Classes

except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout): + except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( - path, method, http_function, no_of_tries - 1 + path, method, http_function, no_of_tries - 1, retry_info_map ) except Exception as e: raise_general_exception(e) From 54595d9fffe375eb1577e66302d750230ea12ca0 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 29 Aug 2023 13:53:32 +0530 Subject: [PATCH 05/12] test: Fix failing test for the 0.14 patch release --- tests/test_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_session.py b/tests/test_session.py index cc8441de4..469869307 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -794,7 +794,7 @@ async def test_expose_access_token_to_frontend_in_cookie_based_auth( assert response.status_code == 200 assert len(response.headers["st-access-token"]) > 0 - reset(stop_core=False) + reset() args = get_st_init_args([session.init(expose_access_token_to_frontend_in_cookie_based_auth=False, get_token_transfer_method=lambda *_: "cookie")]) # type: ignore init(**args) # type: ignore From 8fad5bcdd080d29591141888eef07716df627c33 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 29 Aug 2023 13:56:22 +0530 Subject: [PATCH 06/12] adding dev-v0.14.9 tag to this commit to ensure building From b4b0c4a8ad52838c7f39a5d2167e09aedc8580ac Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 17:14:38 +0530 Subject: [PATCH 07/12] adds nestjs patch --- supertokens_python/async_to_sync_wrapper.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 4a56ea31b..a8178f0e2 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import nest_asyncio # type: ignore import asyncio from typing import Any, Coroutine, TypeVar @@ -24,10 +25,12 @@ def check_event_loop(): except RuntimeError as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() + nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) def sync(co: Coroutine[Any, Any, _T]) -> _T: check_event_loop() loop = asyncio.get_event_loop() + nest_asyncio.apply(loop) # type: ignore return loop.run_until_complete(co) From c1215c8dfc96d805f6a4aa055c64f26603faf86d Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 17:24:18 +0530 Subject: [PATCH 08/12] bumps version --- CHANGELOG.md | 4 ++++ setup.py | 2 +- supertokens_python/constants.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bc719ed1..6ef325495 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.14.10] - 2023-09-31 + +- Uses nest_asyncio patch in event loop - sync to async + ## [0.14.9] - 2023-09-28 - Add logic to retry network calls if the core returns status 429 diff --git a/setup.py b/setup.py index 0fff1a728..bd1100a2f 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.14.9", + version="0.14.10", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 32e942570..b8886e6f2 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.9" +VERSION = "0.14.10" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" From 4d20336929ce23689f27e39992d5316522c96d15 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 17:35:32 +0530 Subject: [PATCH 09/12] adds missing dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index bd1100a2f..333c5cf59 100644 --- a/setup.py +++ b/setup.py @@ -111,6 +111,7 @@ "phonenumbers==8.12.48", "twilio==7.9.1", "aiosmtplib==1.1.6", + "nest-asyncio==1.5.1", ], python_requires=">=3.7", include_package_data=True, From d900f065687445c058131b859b00b34143e3248b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 20:10:19 +0530 Subject: [PATCH 10/12] more changes --- supertokens_python/async_to_sync_wrapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index a8178f0e2..0e9286ee7 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -25,7 +25,6 @@ def check_event_loop(): except RuntimeError as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() - nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) From dbca4da04d3648143db567e6d327c6423593bdd9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 31 Aug 2023 21:04:16 +0530 Subject: [PATCH 11/12] more changes --- supertokens_python/async_to_sync_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 0e9286ee7..9c623bf51 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -25,11 +25,11 @@ def check_event_loop(): except RuntimeError as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() + nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) def sync(co: Coroutine[Any, Any, _T]) -> _T: check_event_loop() loop = asyncio.get_event_loop() - nest_asyncio.apply(loop) # type: ignore return loop.run_until_complete(co) From b1e3d44173cf382e6ffc32e1ae6b0ee5422f560f Mon Sep 17 00:00:00 2001 From: KShivendu Date: Tue, 19 Sep 2023 11:18:11 +0530 Subject: [PATCH 12/12] fix: Async lib not found error --- CHANGELOG.md | 2 +- supertokens_python/async_to_sync_wrapper.py | 11 +-- supertokens_python/constants.py | 2 +- supertokens_python/querier.py | 87 +++++++++++++++------ supertokens_python/utils.py | 5 +- 5 files changed, 71 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca30a82d3..f380d3c67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -541,7 +541,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.14.10] - 2023-09-31 -- Uses nest_asyncio patch in event loop - sync to async +- Uses `nest_asyncio` patch in event loop - sync to async ## [0.14.9] - 2023-09-28 diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 9c623bf51..0e3d27486 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -19,17 +19,18 @@ _T = TypeVar("_T") -def check_event_loop(): +def create_or_get_event_loop() -> asyncio.AbstractEventLoop: try: - asyncio.get_event_loop() - except RuntimeError as ex: + return asyncio.get_event_loop() + except Exception as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() nest_asyncio.apply(loop) # type: ignore asyncio.set_event_loop(loop) + return loop + raise ex def sync(co: Coroutine[Any, Any, _T]) -> _T: - check_event_loop() - loop = asyncio.get_event_loop() + loop = create_or_get_event_loop() return loop.run_until_complete(co) diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 075ff9309..292277329 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.16.0" +VERSION = "0.16.1" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 79f6ae3dc..db520f11a 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -39,6 +39,8 @@ from .exceptions import raise_general_exception from .process_state import AllowedProcessStates, ProcessState from .utils import find_max_version, is_4xx_error, is_5xx_error +from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop +from sniffio import AsyncLibraryNotFoundError class Querier: @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing(): raise_general_exception("calling testing function in non testing env") return Querier.__hosts_alive_for_testing + async def api_request( + self, + url: str, + method: str, + attempts_remaining: int, + *args: Any, + **kwargs: Any, + ) -> Response: + if attempts_remaining == 0: + raise_general_exception("Retry request failed") + + try: + async with AsyncClient() as client: + if method == "GET": + return await client.get(url, *args, **kwargs) # type: ignore + if method == "POST": + return await client.post(url, *args, **kwargs) # type: ignore + if method == "PUT": + return await client.put(url, *args, **kwargs) # type: ignore + if method == "DELETE": + return await client.delete(url, *args, **kwargs) # type: ignore + raise Exception("Shouldn't come here") + except AsyncLibraryNotFoundError: + # Retry + loop = create_or_get_event_loop() + return loop.run_until_complete( + self.api_request(url, method, attempts_remaining - 1, *args, **kwargs) + ) + async def get_api_version(self): if Querier.api_version is not None: return Querier.api_version @@ -79,12 +110,11 @@ async def get_api_version(self): AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION ) - async def f(url: str) -> Response: + async def f(url: str, method: str) -> Response: headers = {} if Querier.__api_key is not None: headers = {API_KEY_HEADER: Querier.__api_key} - async with AsyncClient() as client: - return await client.get(url, headers=headers) # type:ignore + return await self.api_request(url, method, 2, headers=headers) response = await self.__send_request_helper( NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts) @@ -134,13 +164,14 @@ async def send_get_request( if params is None: params = {} - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.get( # type:ignore - url, - params=params, - headers=await self.__get_headers_with_api_version(path), - ) + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + params=params, + ) return await self.__send_request_helper(path, "GET", f, len(self.__hosts)) @@ -163,9 +194,14 @@ async def send_post_request( headers = await self.__get_headers_with_api_version(path) headers["content-type"] = "application/json; charset=utf-8" - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.post(url, json=data, headers=headers) # type: ignore + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + json=data, + ) return await self.__send_request_helper(path, "POST", f, len(self.__hosts)) @@ -175,13 +211,14 @@ async def send_delete_request( if params is None: params = {} - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.delete( # type:ignore - url, - params=params, - headers=await self.__get_headers_with_api_version(path), - ) + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + params=params, + ) return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts)) @@ -194,9 +231,8 @@ async def send_put_request( headers = await self.__get_headers_with_api_version(path) headers["content-type"] = "application/json; charset=utf-8" - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.put(url, json=data, headers=headers) # type: ignore + async def f(url: str, method: str) -> Response: + return await self.api_request(url, method, 2, headers=headers, json=data) return await self.__send_request_helper(path, "PUT", f, len(self.__hosts)) @@ -223,7 +259,7 @@ async def __send_request_helper( self, path: NormalisedURLPath, method: str, - http_function: Callable[[str], Awaitable[Response]], + http_function: Callable[[str, str], Awaitable[Response]], no_of_tries: int, retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: @@ -253,7 +289,7 @@ async def __send_request_helper( ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) - response = await http_function(url) + response = await http_function(url, method) if ("SUPERTOKENS_ENV" in environ) and ( environ["SUPERTOKENS_ENV"] == "testing" ): @@ -289,7 +325,6 @@ async def __send_request_helper( return response.json() except JSONDecodeError: return response.text - except (ConnectionError, NetworkError, ConnectTimeout) as _: return await self.__send_request_helper( path, method, http_function, no_of_tries - 1, retry_info_map diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index a79d182c1..d504b8fe3 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -39,7 +39,7 @@ from httpx import HTTPStatusError, Response from tldextract import extract # type: ignore -from supertokens_python.async_to_sync_wrapper import check_event_loop +from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop from supertokens_python.framework.django.framework import DjangoFramework from supertokens_python.framework.fastapi.framework import FastapiFramework from supertokens_python.framework.flask.framework import FlaskFramework @@ -212,8 +212,7 @@ def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]): if real_mode == "wsgi": asyncio.run(func()) else: - check_event_loop() - loop = asyncio.get_event_loop() + loop = create_or_get_event_loop() loop.create_task(func())