diff --git a/.circleci/config_continue.yml b/.circleci/config_continue.yml index 004004ce1..25ae526f7 100644 --- a/.circleci/config_continue.yml +++ b/.circleci/config_continue.yml @@ -79,6 +79,22 @@ jobs: - run: make with-django2x - run: (cd .circleci/ && ./websiteDjango2x.sh) - slack/status + test-website-flask-nest-asyncio: + docker: + - image: rishabhpoddar/supertokens_python_driver_testing + resource_class: large + environment: + SUPERTOKENS_NEST_ASYNCIO: "1" + steps: + - checkout + - run: update-alternatives --install "/usr/bin/java" "java" "/usr/java/jdk-15.0.1/bin/java" 2 + - run: update-alternatives --install "/usr/bin/javac" "javac" "/usr/java/jdk-15.0.1/bin/javac" 2 + - run: git config --global url."https://github.com/".insteadOf ssh://git@github.com/ + - run: echo "127.0.0.1 localhost.org" >> /etc/hosts + - run: make with-flask + - run: python -m pip install nest-asyncio + - run: (cd .circleci/ && ./websiteFlask.sh) + - slack/status test-authreact-fastapi: docker: - image: rishabhpoddar/supertokens_python_driver_testing diff --git a/CHANGELOG.md b/CHANGELOG.md index 461fea709..4d7ca9e50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.16.2] - 2023-09-20 + +- Allow use of [nest-asyncio](https://pypi.org/project/nest-asyncio/) when env var `SUPERTOKENS_NEST_ASYNCIO=1`. +- Retry Querier request on `AsyncLibraryNotFoundError` + ## [0.16.1] - 2023-09-19 - Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute. diff --git a/setup.py b/setup.py index 2b13aff2e..920610b61 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.16.1", + version="0.16.2", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 79035af08..33b96ea0e 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional from typing_extensions import Literal @@ -32,11 +32,16 @@ def init( framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]], - mode: Union[Literal["asgi", "wsgi"], None] = None, - telemetry: Union[bool, None] = None, + mode: Optional[Literal["asgi", "wsgi"]] = None, + telemetry: Optional[bool] = None, ): return Supertokens.init( - app_info, framework, supertokens_config, recipe_list, mode, telemetry + app_info, + framework, + supertokens_config, + recipe_list, + mode, + telemetry, ) diff --git a/supertokens_python/async_to_sync_wrapper.py b/supertokens_python/async_to_sync_wrapper.py index 4a56ea31b..8e019336d 100644 --- a/supertokens_python/async_to_sync_wrapper.py +++ b/supertokens_python/async_to_sync_wrapper.py @@ -14,20 +14,32 @@ import asyncio from typing import Any, Coroutine, TypeVar +from os import getenv _T = TypeVar("_T") -def check_event_loop(): +def nest_asyncio_enabled(): + return getenv("SUPERTOKENS_NEST_ASYNCIO", "") == "1" + + +def create_or_get_event_loop() -> asyncio.AbstractEventLoop: try: - asyncio.get_event_loop() - except RuntimeError as ex: + return asyncio.get_event_loop() + except Exception as ex: if "There is no current event loop in thread" in str(ex): loop = asyncio.new_event_loop() + + if nest_asyncio_enabled(): + import nest_asyncio # type: ignore + + nest_asyncio.apply(loop) # type: ignore + asyncio.set_event_loop(loop) + return loop + raise ex def sync(co: Coroutine[Any, Any, _T]) -> _T: - check_event_loop() - loop = asyncio.get_event_loop() + loop = create_or_get_event_loop() return loop.run_until_complete(co) diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 292277329..fbece17e8 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.16.1" +VERSION = "0.16.2" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 79f6ae3dc..66e1a2dd7 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -39,6 +39,8 @@ from .exceptions import raise_general_exception from .process_state import AllowedProcessStates, ProcessState from .utils import find_max_version, is_4xx_error, is_5xx_error +from sniffio import AsyncLibraryNotFoundError +from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop class Querier: @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing(): raise_general_exception("calling testing function in non testing env") return Querier.__hosts_alive_for_testing + async def api_request( + self, + url: str, + method: str, + attempts_remaining: int, + *args: Any, + **kwargs: Any, + ) -> Response: + if attempts_remaining == 0: + raise_general_exception("Retry request failed") + + try: + async with AsyncClient() as client: + if method == "GET": + return await client.get(url, *args, **kwargs) # type: ignore + if method == "POST": + return await client.post(url, *args, **kwargs) # type: ignore + if method == "PUT": + return await client.put(url, *args, **kwargs) # type: ignore + if method == "DELETE": + return await client.delete(url, *args, **kwargs) # type: ignore + raise Exception("Shouldn't come here") + except AsyncLibraryNotFoundError: + # Retry + loop = create_or_get_event_loop() + return loop.run_until_complete( + self.api_request(url, method, attempts_remaining - 1, *args, **kwargs) + ) + async def get_api_version(self): if Querier.api_version is not None: return Querier.api_version @@ -79,12 +110,11 @@ async def get_api_version(self): AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION ) - async def f(url: str) -> Response: + async def f(url: str, method: str) -> Response: headers = {} if Querier.__api_key is not None: headers = {API_KEY_HEADER: Querier.__api_key} - async with AsyncClient() as client: - return await client.get(url, headers=headers) # type:ignore + return await self.api_request(url, method, 2, headers=headers) response = await self.__send_request_helper( NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts) @@ -134,13 +164,14 @@ async def send_get_request( if params is None: params = {} - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.get( # type:ignore - url, - params=params, - headers=await self.__get_headers_with_api_version(path), - ) + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + params=params, + ) return await self.__send_request_helper(path, "GET", f, len(self.__hosts)) @@ -163,9 +194,14 @@ async def send_post_request( headers = await self.__get_headers_with_api_version(path) headers["content-type"] = "application/json; charset=utf-8" - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.post(url, json=data, headers=headers) # type: ignore + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + json=data, + ) return await self.__send_request_helper(path, "POST", f, len(self.__hosts)) @@ -175,13 +211,14 @@ async def send_delete_request( if params is None: params = {} - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.delete( # type:ignore - url, - params=params, - headers=await self.__get_headers_with_api_version(path), - ) + async def f(url: str, method: str) -> Response: + return await self.api_request( + url, + method, + 2, + headers=await self.__get_headers_with_api_version(path), + params=params, + ) return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts)) @@ -194,9 +231,8 @@ async def send_put_request( headers = await self.__get_headers_with_api_version(path) headers["content-type"] = "application/json; charset=utf-8" - async def f(url: str) -> Response: - async with AsyncClient() as client: - return await client.put(url, json=data, headers=headers) # type: ignore + async def f(url: str, method: str) -> Response: + return await self.api_request(url, method, 2, headers=headers, json=data) return await self.__send_request_helper(path, "PUT", f, len(self.__hosts)) @@ -223,7 +259,7 @@ async def __send_request_helper( self, path: NormalisedURLPath, method: str, - http_function: Callable[[str], Awaitable[Response]], + http_function: Callable[[str, str], Awaitable[Response]], no_of_tries: int, retry_info_map: Optional[Dict[str, int]] = None, ) -> Any: @@ -253,7 +289,7 @@ async def __send_request_helper( ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER ) - response = await http_function(url) + response = await http_function(url, method) if ("SUPERTOKENS_ENV" in environ) and ( environ["SUPERTOKENS_ENV"] == "testing" ): diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index ac7885908..c51eca170 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -148,8 +148,8 @@ def __init__( framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, recipe_list: List[Callable[[AppInfo], RecipeModule]], - mode: Union[Literal["asgi", "wsgi"], None], - telemetry: Union[bool, None], + mode: Optional[Literal["asgi", "wsgi"]], + telemetry: Optional[bool], ): if not isinstance(app_info, InputAppInfo): # type: ignore raise ValueError("app_info must be an instance of InputAppInfo") @@ -215,12 +215,17 @@ def init( framework: Literal["fastapi", "flask", "django"], supertokens_config: SupertokensConfig, recipe_list: List[Callable[[AppInfo], RecipeModule]], - mode: Union[Literal["asgi", "wsgi"], None], - telemetry: Union[bool, None], + mode: Optional[Literal["asgi", "wsgi"]], + telemetry: Optional[bool], ): if Supertokens.__instance is None: Supertokens.__instance = Supertokens( - app_info, framework, supertokens_config, recipe_list, mode, telemetry + app_info, + framework, + supertokens_config, + recipe_list, + mode, + telemetry, ) PostSTInitCallbacks.run_post_init_callbacks() diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index 1b8afd85b..3a14a5f7d 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -14,7 +14,6 @@ from __future__ import annotations -import asyncio import json import threading import warnings @@ -27,7 +26,6 @@ Any, Awaitable, Callable, - Coroutine, Dict, List, TypeVar, @@ -39,7 +37,6 @@ from httpx import HTTPStatusError, Response from tldextract import extract # type: ignore -from supertokens_python.async_to_sync_wrapper import check_event_loop from supertokens_python.framework.django.framework import DjangoFramework from supertokens_python.framework.fastapi.framework import FastapiFramework from supertokens_python.framework.flask.framework import FlaskFramework @@ -195,28 +192,6 @@ def find_first_occurrence_in_list( return None -def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]): - real_mode = None - try: - asyncio.get_running_loop() - real_mode = "asgi" - except RuntimeError: - real_mode = "wsgi" - - if mode != real_mode: - warnings.warn( - "Inconsistent mode detected, check if you are using the right asgi / wsgi mode", - category=RuntimeWarning, - ) - - if real_mode == "wsgi": - asyncio.run(func()) - else: - check_event_loop() - loop = asyncio.get_event_loop() - loop.create_task(func()) - - def frontend_has_interceptor(request: BaseRequest) -> bool: return get_rid_from_header(request) is not None