diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d74623f1..5797d50c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.14.1] - 2023-05-23 + +### Changes + +- Added a new `get_request_from_user_context` function that can be used to read the original network request from the user context in overridden APIs and recipe functions + ## [0.14.0] - 2023-05-18 - Adds missing `check_database` boolean in `verify_session` diff --git a/setup.py b/setup.py index 0ba03bead..93e78d8d6 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="supertokens_python", - version="0.14.0", + version="0.14.1", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index 519e7411e..79035af08 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -12,8 +12,11 @@ # License for the specific language governing permissions and limitations # under the License. +from typing import Any, Callable, Dict, List, Optional, Union + from typing_extensions import Literal -from typing import Callable, List, Union + +from supertokens_python.framework.request import BaseRequest from . import supertokens from .recipe_module import RecipeModule @@ -39,3 +42,9 @@ def init( def get_all_cors_headers() -> List[str]: return supertokens.Supertokens.get_instance().get_all_cors_headers() + + +def get_request_from_user_context( + user_context: Optional[Dict[str, Any]], +) -> Optional[BaseRequest]: + return Supertokens.get_instance().get_request_from_user_context(user_context) diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index f6ce6a6bd..ab3b7c89c 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -11,18 +11,18 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import List, Union, Optional, Dict +from typing import Dict, List, Optional, Union from supertokens_python import Supertokens from supertokens_python.interfaces import ( CreateUserIdMappingOkResult, + DeleteUserIdMappingOkResult, + GetUserIdMappingOkResult, + UnknownMappingError, UnknownSupertokensUserIDError, + UpdateOrDeleteUserIdMappingInfoOkResult, UserIdMappingAlreadyExistsError, UserIDTypes, - UnknownMappingError, - GetUserIdMappingOkResult, - DeleteUserIdMappingOkResult, - UpdateOrDeleteUserIdMappingInfoOkResult, ) from supertokens_python.types import UsersResponse diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 4b4efbeff..0969f1889 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["2.21"] -VERSION = "0.14.0" +VERSION = "0.14.1" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 0a0e88227..5e9814ff4 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -552,3 +552,18 @@ async def handle_supertokens_error( ) return await recipe.handle_error(request, err, response) raise err + + def get_request_from_user_context( # pylint: disable=no-self-use + self, + user_context: Optional[Dict[str, Any]] = None, + ) -> Optional[BaseRequest]: + if user_context is None: + return None + + if "_default" not in user_context: + return None + + if not isinstance(user_context["_default"], dict): + return None + + return user_context.get("_default", {}).get("request") diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index 86830c34a..24a8ea476 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -11,19 +11,19 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import List, Union, Optional, Dict +from typing import Dict, List, Optional, Union from supertokens_python import Supertokens from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.interfaces import ( CreateUserIdMappingOkResult, + DeleteUserIdMappingOkResult, + GetUserIdMappingOkResult, + UnknownMappingError, UnknownSupertokensUserIDError, + UpdateOrDeleteUserIdMappingInfoOkResult, UserIdMappingAlreadyExistsError, UserIDTypes, - UnknownMappingError, - GetUserIdMappingOkResult, - DeleteUserIdMappingOkResult, - UpdateOrDeleteUserIdMappingInfoOkResult, ) from supertokens_python.types import UsersResponse diff --git a/tests/test_user_context.py b/tests/test_user_context.py index 2fb9adbcc..dd2e36f15 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -13,8 +13,16 @@ # under the License. from typing import Any, Dict, List, Optional +from fastapi import FastAPI +from fastapi.testclient import TestClient from pytest import fixture, mark -from supertokens_python import InputAppInfo, SupertokensConfig, init + +from supertokens_python import ( + InputAppInfo, + SupertokensConfig, + get_request_from_user_context, + init, +) from supertokens_python.framework.fastapi import get_middleware from supertokens_python.recipe import emailpassword, session from supertokens_python.recipe.emailpassword.asyncio import sign_up @@ -28,9 +36,6 @@ RecipeInterface as SRecipeInterface, ) -from fastapi import FastAPI -from fastapi.testclient import TestClient - from .utils import clean_st, reset, setup_st, sign_in_request, start_st works = False @@ -277,3 +282,121 @@ async def create_new_session( create_new_session_context_works, ] ) + + +@mark.asyncio +async def test_get_request_from_user_context(driver_config_client: TestClient): + signin_api_context_works, signin_context_works, create_new_session_context_works = ( + False, + False, + False, + ) + + def apis_override_email_password(param: APIInterface): + og_sign_in_post = param.sign_in_post + + async def sign_in_post( + form_fields: List[FormField], + api_options: APIOptions, + user_context: Dict[str, Any], + ): + req = get_request_from_user_context(user_context) + if req: + assert req.method() == "POST" + assert req.get_path() == "/auth/signin" + nonlocal signin_api_context_works + signin_api_context_works = True + + return await og_sign_in_post(form_fields, api_options, user_context) + + param.sign_in_post = sign_in_post + return param + + def functions_override_email_password(param: RecipeInterface): + og_sign_in = param.sign_in + + async def sign_in(email: str, password: str, user_context: Dict[str, Any]): + req = get_request_from_user_context(user_context) + if req: + assert req.method() == "POST" + assert req.get_path() == "/auth/signin" + nonlocal signin_context_works + signin_context_works = True + + orginal_request = req + user_context["_default"]["request"] = None + + newReq = get_request_from_user_context(user_context) + assert newReq is None + + user_context["_default"]["request"] = orginal_request + + return await og_sign_in(email, password, user_context) + + param.sign_in = sign_in + return param + + def functions_override_session(param: SRecipeInterface): + og_create_new_session = param.create_new_session + + async def create_new_session( + user_id: str, + access_token_payload: Optional[Dict[str, Any]], + session_data_in_database: Optional[Dict[str, Any]], + disable_anti_csrf: Optional[bool], + user_context: Dict[str, Any], + ): + req = get_request_from_user_context(user_context) + if req: + assert req.method() == "POST" + assert req.get_path() == "/auth/signin" + nonlocal create_new_session_context_works + create_new_session_context_works = True + + response = await og_create_new_session( + user_id, + access_token_payload, + session_data_in_database, + disable_anti_csrf, + user_context, + ) + return response + + param.create_new_session = create_new_session + return param + + init( + supertokens_config=SupertokensConfig("http://localhost:3567"), + app_info=InputAppInfo( + app_name="SuperTokens Demo", + api_domain="http://api.supertokens.io", + website_domain="http://supertokens.io", + ), + framework="fastapi", + recipe_list=[ + emailpassword.init( + override=emailpassword.InputOverrideConfig( + apis=apis_override_email_password, + functions=functions_override_email_password, + ) + ), + session.init( + override=session.InputOverrideConfig( + functions=functions_override_session + ) + ), + ], + ) + start_st() + + await sign_up("random@gmail.com", "validpass123", {"manualCall": True}) + res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") + + assert res.status_code == 200 + assert all( + [ + signin_api_context_works, + signin_context_works, + create_new_session_context_works, + ] + )