Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use nest-asyncio when configured with env var #451

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.16.2] - 2023-09-20

- Allow use of [nest-asyncio](https://pypi.org/project/nest-asyncio/) when env var `SUPERTOKENS_NEST_ASYNCIO=1`.
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved
- Retry Querier request on `AsyncLibraryNotFoundError`

## [0.16.1] - 2023-09-19
- Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute.

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.16.1",
version="0.16.2",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
13 changes: 9 additions & 4 deletions supertokens_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.

from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

from typing_extensions import Literal

Expand All @@ -32,11 +32,16 @@ def init(
framework: Literal["fastapi", "flask", "django"],
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None] = None,
telemetry: Union[bool, None] = None,
mode: Optional[Literal["asgi", "wsgi"]] = None,
telemetry: Optional[bool] = None,
):
return Supertokens.init(
app_info, framework, supertokens_config, recipe_list, mode, telemetry
app_info,
framework,
supertokens_config,
recipe_list,
mode,
telemetry,
)


Expand Down
27 changes: 22 additions & 5 deletions supertokens_python/async_to_sync_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,37 @@

import asyncio
from typing import Any, Coroutine, TypeVar
from os import getenv

_T = TypeVar("_T")


def check_event_loop():
def is_nest_asyncio_enabled():
try:
asyncio.get_event_loop()
except RuntimeError as ex:
import nest_asyncio as _ # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this import here?


return getenv("SUPERTOKENS_NEST_ASYNCIO", "") == "1"
except Exception:
return False

rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure that there is at least one thing in cicd where we run the python sdk (for example in e2e test with flask), where nest_asyncio is not a dependency and then we make sure it all works

def create_or_get_event_loop() -> asyncio.AbstractEventLoop:
try:
return asyncio.get_event_loop()
except Exception as ex:
if "There is no current event loop in thread" in str(ex):
loop = asyncio.new_event_loop()

if is_nest_asyncio_enabled():
import nest_asyncio # type: ignore

nest_asyncio.apply(loop) # type: ignore

asyncio.set_event_loop(loop)
return loop
raise ex


def sync(co: Coroutine[Any, Any, _T]) -> _T:
check_event_loop()
loop = asyncio.get_event_loop()
loop = create_or_get_event_loop()
return loop.run_until_complete(co)
2 changes: 1 addition & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.16.1"
VERSION = "0.16.2"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand Down
86 changes: 61 additions & 25 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from .exceptions import raise_general_exception
from .process_state import AllowedProcessStates, ProcessState
from .utils import find_max_version, is_4xx_error, is_5xx_error
from sniffio import AsyncLibraryNotFoundError
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop


class Querier:
Expand Down Expand Up @@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
raise_general_exception("calling testing function in non testing env")
return Querier.__hosts_alive_for_testing

async def api_request(
self,
url: str,
method: str,
attempts_remaining: int,
*args: Any,
**kwargs: Any,
) -> Response:
if attempts_remaining == 0:
raise_general_exception("Retry request failed")

try:
async with AsyncClient() as client:
if method == "GET":
return await client.get(url, *args, **kwargs) # type: ignore
if method == "POST":
return await client.post(url, *args, **kwargs) # type: ignore
if method == "PUT":
return await client.put(url, *args, **kwargs) # type: ignore
if method == "DELETE":
return await client.delete(url, *args, **kwargs) # type: ignore
raise Exception("Shouldn't come here")
except AsyncLibraryNotFoundError:
# Retry
loop = create_or_get_event_loop()
return loop.run_until_complete(
self.api_request(url, method, attempts_remaining - 1, *args, **kwargs)
)

async def get_api_version(self):
if Querier.api_version is not None:
return Querier.api_version
Expand All @@ -79,12 +110,11 @@ async def get_api_version(self):
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
)

async def f(url: str) -> Response:
async def f(url: str, method: str) -> Response:
headers = {}
if Querier.__api_key is not None:
headers = {API_KEY_HEADER: Querier.__api_key}
async with AsyncClient() as client:
return await client.get(url, headers=headers) # type:ignore
return await self.api_request(url, method, 2, headers=headers)

response = await self.__send_request_helper(
NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts)
Expand Down Expand Up @@ -134,13 +164,14 @@ async def send_get_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.get( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "GET", f, len(self.__hosts))

Expand All @@ -163,9 +194,14 @@ async def send_post_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.post(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
json=data,
)

return await self.__send_request_helper(path, "POST", f, len(self.__hosts))

Expand All @@ -175,13 +211,14 @@ async def send_delete_request(
if params is None:
params = {}

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.delete( # type:ignore
url,
params=params,
headers=await self.__get_headers_with_api_version(path),
)
async def f(url: str, method: str) -> Response:
return await self.api_request(
url,
method,
2,
headers=await self.__get_headers_with_api_version(path),
params=params,
)

return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))

Expand All @@ -194,9 +231,8 @@ async def send_put_request(
headers = await self.__get_headers_with_api_version(path)
headers["content-type"] = "application/json; charset=utf-8"

async def f(url: str) -> Response:
async with AsyncClient() as client:
return await client.put(url, json=data, headers=headers) # type: ignore
async def f(url: str, method: str) -> Response:
return await self.api_request(url, method, 2, headers=headers, json=data)

return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))

Expand All @@ -223,7 +259,7 @@ async def __send_request_helper(
self,
path: NormalisedURLPath,
method: str,
http_function: Callable[[str], Awaitable[Response]],
http_function: Callable[[str, str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
Expand Down Expand Up @@ -253,7 +289,7 @@ async def __send_request_helper(
ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
)
response = await http_function(url)
response = await http_function(url, method)
if ("SUPERTOKENS_ENV" in environ) and (
environ["SUPERTOKENS_ENV"] == "testing"
):
Expand Down
15 changes: 10 additions & 5 deletions supertokens_python/supertokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def __init__(
framework: Literal["fastapi", "flask", "django"],
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None],
telemetry: Union[bool, None],
mode: Optional[Literal["asgi", "wsgi"]],
telemetry: Optional[bool],
):
if not isinstance(app_info, InputAppInfo): # type: ignore
raise ValueError("app_info must be an instance of InputAppInfo")
Expand Down Expand Up @@ -215,12 +215,17 @@ def init(
framework: Literal["fastapi", "flask", "django"],
supertokens_config: SupertokensConfig,
recipe_list: List[Callable[[AppInfo], RecipeModule]],
mode: Union[Literal["asgi", "wsgi"], None],
telemetry: Union[bool, None],
mode: Optional[Literal["asgi", "wsgi"]],
telemetry: Optional[bool],
):
if Supertokens.__instance is None:
Supertokens.__instance = Supertokens(
app_info, framework, supertokens_config, recipe_list, mode, telemetry
app_info,
framework,
supertokens_config,
recipe_list,
mode,
telemetry,
)
PostSTInitCallbacks.run_post_init_callbacks()

Expand Down
25 changes: 0 additions & 25 deletions supertokens_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

import asyncio
import json
import threading
import warnings
Expand All @@ -27,7 +26,6 @@
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
TypeVar,
Expand All @@ -39,7 +37,6 @@
from httpx import HTTPStatusError, Response
from tldextract import extract # type: ignore

from supertokens_python.async_to_sync_wrapper import check_event_loop
from supertokens_python.framework.django.framework import DjangoFramework
from supertokens_python.framework.fastapi.framework import FastapiFramework
from supertokens_python.framework.flask.framework import FlaskFramework
Expand Down Expand Up @@ -195,28 +192,6 @@ def find_first_occurrence_in_list(
return None


def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]):
real_mode = None
try:
asyncio.get_running_loop()
real_mode = "asgi"
except RuntimeError:
real_mode = "wsgi"

if mode != real_mode:
warnings.warn(
"Inconsistent mode detected, check if you are using the right asgi / wsgi mode",
category=RuntimeWarning,
)

if real_mode == "wsgi":
asyncio.run(func())
else:
check_event_loop()
loop = asyncio.get_event_loop()
loop.create_task(func())
rishabhpoddar marked this conversation as resolved.
Show resolved Hide resolved


def frontend_has_interceptor(request: BaseRequest) -> bool:
return get_rid_from_header(request) is not None

Expand Down