diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 49c16fee..7e1eedc7 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify +from supertokens_python.framework import BaseRequest from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe @@ -49,11 +50,21 @@ def default_st_init(): + def origin_func( + request: Optional[BaseRequest] = None, context: Dict[str, Any] = {} + ) -> str: + if request is None: + return "http://localhost:8080" + origin = request.get_header("origin") + if origin is not None: + return origin + return "http://localhost:8080" + init( app_info=InputAppInfo( app_name="SuperTokens", api_domain="http://api.supertokens.io", - website_domain="http://localhost:3000", + origin=origin_func, ), supertokens_config=SupertokensConfig(connection_uri="http://localhost:3567"), framework="flask", diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index a7dab1b7..f57c3057 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -4,6 +4,7 @@ from httpx import Response from supertokens_python.framework.flask.flask_request import FlaskRequest +from supertokens_python.types import RecipeUserId override_logs: List[Dict[str, Any]] = [] @@ -40,5 +41,7 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return "FlaskRequest" if isinstance(data, Response): return "Response" + if isinstance(data, RecipeUserId): + return data.get_as_string() return data diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 1d057412..2d60c50a 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,6 +1,11 @@ +from typing import Any from flask import Flask, request, jsonify +from supertokens_python.recipe.session.interfaces import TokenInfo +from supertokens_python.recipe.session.jwt import ( + parse_jwt_without_signature_verification, +) +from supertokens_python.types import RecipeUserId from utils import deserialize_validator -from supertokens_python import async_to_sync_wrapper from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session @@ -22,7 +27,7 @@ def create_new_session_without_request_response(): # type: ignore session_container = session.create_new_session_without_request_response( tenant_id, - user_id, + RecipeUserId(user_id), access_token_payload, session_data_in_database, disable_anti_csrf, @@ -76,29 +81,13 @@ def assert_claims(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - session_container = Session( - recipe_implementation=SessionRecipe.get_instance().recipe_implementation, - config=SessionRecipe.get_instance().config, - access_token=data["session"]["accessToken"], - front_token=data["session"]["frontToken"], - refresh_token=None, # We don't have refresh token in the input - anti_csrf_token=None, # We don't have anti-csrf token in the input - session_handle=data["session"]["sessionHandle"], - user_id=data["session"]["userId"], - recipe_user_id=data["session"]["recipeUserId"], - user_data_in_access_token=data["session"]["userDataInAccessToken"], - req_res_info=None, # We don't have this information in the input - access_token_updated=data["session"]["accessTokenUpdated"], - tenant_id=data["session"]["tenantId"], - ) + session_container = convert_session_to_container(data) claim_validators = list(map(deserialize_validator, data["claimValidators"])) user_context = data.get("userContext", {}) try: - async_to_sync_wrapper.sync( - session_container.assert_claims(claim_validators, user_context) - ) + session_container.sync_assert_claims(claim_validators, user_context) return jsonify( { "status": "OK", @@ -135,28 +124,13 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - session_container = Session( - recipe_implementation=SessionRecipe.get_instance().recipe_implementation, - config=SessionRecipe.get_instance().config, - access_token=data["session"]["accessToken"], - front_token=data["session"]["frontToken"], - refresh_token=None, # We don't have refresh token in the input - anti_csrf_token=None, # We don't have anti-csrf token in the input - session_handle=data["session"]["sessionHandle"], - user_id=data["session"]["userId"], - recipe_user_id=data["session"]["recipeUserId"], - user_data_in_access_token=data["session"]["userDataInAccessToken"], - req_res_info=None, # We don't have this information in the input - access_token_updated=data["session"]["accessTokenUpdated"], - tenant_id=data["session"]["tenantId"], - ) + session_container = convert_session_to_container(data) + access_token_payload_update = data["accessTokenPayloadUpdate"] user_context = data.get("userContext", {}) - async_to_sync_wrapper.sync( - session_container.merge_into_access_token_payload( - access_token_payload_update, user_context - ) + session_container.sync_merge_into_access_token_payload( + access_token_payload_update, user_context ) return jsonify( @@ -165,6 +139,7 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore "updatedSession": { "sessionHandle": session_container.get_handle(), "userId": session_container.get_user_id(), + "recipeUserId": session_container.get_recipe_user_id().get_as_string(), "tenantId": session_container.get_tenant_id(), "userDataInAccessToken": session_container.get_access_token_payload(), "accessToken": session_container.get_access_token(), @@ -183,3 +158,51 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore }, } ) + + +def convert_session_to_container(data: Any) -> Session: + jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) + jwt_payload = jwt_info.payload + + user_id = jwt_info.version == 2 and jwt_payload["userId"] or jwt_payload["sub"] + session_handle = jwt_payload["sessionHandle"] + + recipe_user_id = RecipeUserId(jwt_payload.get("rsub", user_id)) + anti_csrf_token = jwt_payload.get("antiCsrfToken") + tenant_id = jwt_info.version >= 4 and jwt_payload["tId"] or "public" + + return Session( + recipe_implementation=SessionRecipe.get_instance().recipe_implementation, + config=SessionRecipe.get_instance().config, + access_token=data["session"]["accessToken"], + front_token=data["session"]["frontToken"], + refresh_token=( + TokenInfo( + ( + data["session"]["refreshToken"] + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["token"] + ), + ( + -1 + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["expiry"] + ), + ( + -1 + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["createdTime"] + ), + ) + if "refreshToken" in data["session"] + else None + ), + anti_csrf_token=anti_csrf_token, + session_handle=session_handle, + user_id=user_id, + recipe_user_id=recipe_user_id, + user_data_in_access_token=jwt_payload, + req_res_info=None, # We don't have this information in the input + access_token_updated=False, + tenant_id=tenant_id, + )