From 8020175e5947a2fe07bc07be2d3ae08d080ce792 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 21 Jul 2024 15:22:15 +0530 Subject: [PATCH 001/126] few modifications --- supertokens_python/constants.py | 1 + supertokens_python/normalised_url_domain.py | 17 +++-------- supertokens_python/normalised_url_path.py | 34 +++++---------------- supertokens_python/post_init_callbacks.py | 9 +++--- 4 files changed, 18 insertions(+), 43 deletions(-) diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index f7e60529a..387feab5c 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -11,6 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + from __future__ import annotations SUPPORTED_CDI_VERSIONS = ["3.0"] diff --git a/supertokens_python/normalised_url_domain.py b/supertokens_python/normalised_url_domain.py index ffd57a738..1309f234b 100644 --- a/supertokens_python/normalised_url_domain.py +++ b/supertokens_python/normalised_url_domain.py @@ -11,31 +11,29 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from __future__ import annotations +from __future__ import annotations from typing import TYPE_CHECKING from urllib.parse import urlparse - from .utils import is_an_ip_address +from .exceptions import raise_general_exception if TYPE_CHECKING: pass -from .exceptions import raise_general_exception class NormalisedURLDomain: def __init__(self, url: str): - self.__value = normalise_domain_path_or_throw_error(url) + self.__value = normalise_url_domain_or_throw_error(url) def get_as_string_dangerous(self): return self.__value -def normalise_domain_path_or_throw_error( +def normalise_url_domain_or_throw_error( input_str: str, ignore_protocol: bool = False ) -> str: input_str = input_str.strip().lower() - try: if ( (not input_str.startswith("http://")) @@ -44,7 +42,6 @@ def normalise_domain_path_or_throw_error( ): raise Exception("converting to proper URL") url_obj = urlparse(input_str) - if ignore_protocol: if url_obj.hostname is None: raise Exception("Should never come here") @@ -56,17 +53,13 @@ def normalise_domain_path_or_throw_error( input_str = "https://" + url_obj.netloc else: input_str = url_obj.scheme + "://" + url_obj.netloc - return input_str except Exception: pass - if input_str.startswith("/"): raise_general_exception("Please provide a valid domain name") - if input_str.startswith("."): input_str = input_str[1:] - if ( ("." in input_str or input_str.startswith("localhost")) and (not input_str.startswith("http://")) @@ -75,7 +68,7 @@ def normalise_domain_path_or_throw_error( input_str = "https://" + input_str try: urlparse(input_str) - return normalise_domain_path_or_throw_error(input_str, True) + return normalise_url_domain_or_throw_error(input_str, True) except Exception: pass raise_general_exception("Please provide a valid domain name") diff --git a/supertokens_python/normalised_url_path.py b/supertokens_python/normalised_url_path.py index a99142415..11cf9c5c5 100644 --- a/supertokens_python/normalised_url_path.py +++ b/supertokens_python/normalised_url_path.py @@ -13,13 +13,12 @@ # under the License. from __future__ import annotations - from typing import TYPE_CHECKING from urllib.parse import urlparse +from .exceptions import raise_general_exception if TYPE_CHECKING: pass -from .exceptions import raise_general_exception class NormalisedURLPath: @@ -40,40 +39,30 @@ def equals(self, other: NormalisedURLPath) -> bool: def is_a_recipe_path(self) -> bool: parts = self.__value.split("/") - return (len(parts) > 1 and parts[1] == "recipe") or ( - len(parts) > 2 and parts[2] == "recipe" - ) + return parts[1] == "recipe" or (len(parts) > 2 and parts[2] == "recipe") def normalise_url_path_or_throw_error(input_str: str) -> str: input_str = input_str.strip().lower() - try: - if (not input_str.startswith("http://")) and ( - not input_str.startswith("https://") - ): + if not input_str.startswith("http://") and not input_str.startswith("https://"): raise Exception("converting to proper URL") url_obj = urlparse(input_str) input_str = url_obj.path - if input_str.endswith("/"): return input_str[:-1] - return input_str except Exception: pass - if ( (domain_given(input_str) or input_str.startswith("localhost")) - and (not input_str.startswith("http://")) - and (not input_str.startswith("https://")) + and not input_str.startswith("http://") + and not input_str.startswith("https://") ): input_str = "http://" + input_str return normalise_url_path_or_throw_error(input_str) - if not input_str.startswith("/"): input_str = "/" + input_str - try: urlparse("http://example.com" + input_str) return normalise_url_path_or_throw_error("http://example.com" + input_str) @@ -82,23 +71,16 @@ def normalise_url_path_or_throw_error(input_str: str) -> str: def domain_given(input_str: str) -> bool: - if ("." not in input_str) or (input_str.startswith("/")): + if "." not in input_str or input_str.startswith("/"): return False - try: url = urlparse(input_str) - if url.hostname is None: - raise Exception("Should never come here") - return url.hostname.find(".") != -1 + return url.hostname is not None and "." in url.hostname except Exception: pass - try: url = urlparse("http://" + input_str) - if url.hostname is None: - raise Exception("Should never come here") - return url.hostname.find(".") != -1 + return url.hostname is not None and "." in url.hostname except Exception: pass - return False diff --git a/supertokens_python/post_init_callbacks.py b/supertokens_python/post_init_callbacks.py index 227acb78e..982e1b741 100644 --- a/supertokens_python/post_init_callbacks.py +++ b/supertokens_python/post_init_callbacks.py @@ -18,15 +18,14 @@ class PostSTInitCallbacks: """Callbacks that are called after the SuperTokens instance is initialized.""" - callbacks: List[Callable[[], None]] = [] + post_init_callbacks: List[Callable[[], None]] = [] @staticmethod def add_post_init_callback(cb: Callable[[], None]) -> None: - PostSTInitCallbacks.callbacks.append(cb) + PostSTInitCallbacks.post_init_callbacks.append(cb) @staticmethod def run_post_init_callbacks() -> None: - for cb in PostSTInitCallbacks.callbacks: + for cb in PostSTInitCallbacks.post_init_callbacks: cb() - - PostSTInitCallbacks.callbacks = [] + PostSTInitCallbacks.post_init_callbacks = [] From df89713b29c259638be3ad63861180a99b4713ab Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 21 Jul 2024 16:06:11 +0530 Subject: [PATCH 002/126] adds new type of user --- supertokens_python/types.py | 115 +++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 2 deletions(-) diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 1c7b2f799..27691329e 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -11,16 +11,127 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + from abc import ABC, abstractmethod from typing import Any, Awaitable, Dict, List, TypeVar, Union +from phonenumbers import format_number, parse # type: ignore +import phonenumbers # type: ignore _T = TypeVar("_T") +class RecipeUserId: + def __init__(self, recipe_user_id: str): + self.recipe_user_id = recipe_user_id + + def get_as_string(self) -> str: + return self.recipe_user_id + + class ThirdPartyInfo: - def __init__(self, third_party_user_id: str, third_party_id: str): - self.user_id = third_party_user_id + def __init__(self, third_party_id: str, third_party_user_id: str): self.id = third_party_id + self.user_id = third_party_user_id + + +class LoginMethod: + def __init__( + self, + recipe_id: str, + recipe_user_id: str, + tenant_ids: List[str], + email: Union[str, None], + phone_number: Union[str, None], + third_party: Union[ThirdPartyInfo, None], + time_joined: int, + verified: bool, + ): + self.recipe_id = recipe_id + self.recipe_user_id = RecipeUserId(recipe_user_id) + self.tenant_ids = tenant_ids + self.email = email + self.phone_number = phone_number + self.third_party = third_party + self.time_joined = time_joined + self.verified = verified + + def has_same_email_as(self, email: Union[str, None]) -> bool: + if email is None: + return False + return ( + self.email is not None + and self.email.lower().strip() == email.lower().strip() + ) + + def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: + if phone_number is None: + return False + + cleaned_phone = phone_number.strip() + try: + cleaned_phone = format_number( + parse(phone_number, None), phonenumbers.PhoneNumberFormat.E164 + ) + except Exception: + pass # here we just use the stripped version + + return self.phone_number is not None and self.phone_number == cleaned_phone + + def has_same_third_party_info_as( + self, third_party: Union[ThirdPartyInfo, None] + ) -> bool: + if third_party is None or self.third_party is None: + return False + return ( + self.third_party.id.strip() == third_party.id.strip() + and self.third_party.user_id.strip() == third_party.user_id.strip() + ) + + def to_json(self) -> Dict[str, Any]: + return { + "recipeId": self.recipe_id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "tenantIds": self.tenant_ids, + "email": self.email, + "phoneNumber": self.phone_number, + "thirdParty": self.third_party.__dict__ if self.third_party else None, + "timeJoined": self.time_joined, + "verified": self.verified, + } + + +class AccountLinkingUser: + def __init__( + self, + user_id: str, + is_primary_user: bool, + tenant_ids: List[str], + emails: List[str], + phone_numbers: List[str], + third_party: List[ThirdPartyInfo], + login_methods: List[LoginMethod], + time_joined: int, + ): + self.id = user_id + self.is_primary_user = is_primary_user + self.tenant_ids = tenant_ids + self.emails = emails + self.phone_numbers = phone_numbers + self.third_party = third_party + self.login_methods = login_methods + self.time_joined = time_joined + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "isPrimaryUser": self.is_primary_user, + "tenantIds": self.tenant_ids, + "emails": self.emails, + "phoneNumbers": self.phone_numbers, + "thirdParty": self.third_party, + "loginMethods": [lm.to_json() for lm in self.login_methods], + "timeJoined": self.time_joined, + } class User: From 48a1b651a68bc42a29ba365a52a3e6d876be81c2 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 22 Jul 2024 23:12:43 +0530 Subject: [PATCH 003/126] adds more changes --- .../recipe/session/interfaces.py | 20 ++++- .../recipe/session/session_class.py | 6 ++ supertokens_python/utils.py | 82 ++++++++++++++++++- 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 31802fad5..63bce71c6 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -28,7 +28,12 @@ from typing_extensions import TypedDict from supertokens_python.async_to_sync_wrapper import sync -from supertokens_python.types import APIResponse, GeneralErrorResponse, MaybeAwaitable +from supertokens_python.types import ( + APIResponse, + GeneralErrorResponse, + MaybeAwaitable, + RecipeUserId, +) from ...utils import resolve from .exceptions import ClaimValidationError @@ -416,6 +421,9 @@ def __init__( self.req_res_info: Optional[ReqResInfo] = req_res_info self.access_token_updated = access_token_updated self.tenant_id = tenant_id + self.recipe_user_id = RecipeUserId( + user_id + ) # TODO: change me to be based on input arg. self.response_mutators: List[ResponseMutator] = [] @@ -460,6 +468,12 @@ async def merge_into_access_token_payload( def get_user_id(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass + @abstractmethod + def get_recipe_user_id( + self, user_context: Optional[Dict[str, Any]] = None + ) -> RecipeUserId: + pass + @abstractmethod def get_tenant_id(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass @@ -601,11 +615,11 @@ def sync_remove_claim( def sync_attach_to_request_response( self, request: BaseRequest, - token_transfer: TokenTransferMethod, + transfer_method: TokenTransferMethod, user_context: Dict[str, Any], ) -> None: return sync( - self.attach_to_request_response(request, token_transfer, user_context) + self.attach_to_request_response(request, transfer_method, user_context) ) # This is there so that we can do session["..."] to access some of the members of this class diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index e55a6c5da..dba63b3c0 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -17,6 +17,7 @@ raise_invalid_claims_exception, raise_unauthorised_exception, ) +from supertokens_python.types import RecipeUserId from .jwt import parse_jwt_without_signature_verification from .utils import TokenTransferMethod @@ -139,6 +140,11 @@ async def update_session_data_in_database( def get_user_id(self, user_context: Union[Dict[str, Any], None] = None) -> str: return self.user_id + def get_recipe_user_id( + self, user_context: Union[Dict[str, Any], None] = None + ) -> RecipeUserId: + return self.recipe_user_id + def get_tenant_id(self, user_context: Union[Dict[str, Any], None] = None) -> str: return self.tenant_id diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index aa202a0ec..7e3c43e44 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -44,9 +44,13 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message -from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER +if TYPE_CHECKING: + from supertokens_python.recipe.session import SessionContainer + +from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER, FDI_KEY_HEADER from .exceptions import raise_general_exception from .types import MaybeAwaitable +from supertokens_python.types import AccountLinkingUser _T = TypeVar("_T") @@ -286,7 +290,81 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str: "Please make sure that the apiDomain and websiteDomain have correct values" ) - return parsed_url.domain + "." + parsed_url.suffix # type: ignore + return parsed_url.domain + "." + parsed_url.suffix + + +def get_backwards_compatible_user_info( + req: BaseRequest, + user_info: AccountLinkingUser, + session_container: SessionContainer, + created_new_recipe_user: Union[bool, None], + user_context: Dict[str, Any], +) -> Dict[str, Any]: + resp: Dict[str, Any] = {} + # (>= 1.18 && < 2.0) || >= 3.0: This is because before 1.18, and between 2 and 3, FDI does not + # support account linking. + if ( + has_greater_than_equal_to_fdi(req, "1.18") + and not has_greater_than_equal_to_fdi(req, "2.0") + ) or has_greater_than_equal_to_fdi(req, "3.0"): + resp = {"user": user_info.to_json()} + + if created_new_recipe_user is not None: + resp["createdNewRecipeUser"] = created_new_recipe_user + return resp + + login_method = next( + ( + lm + for lm in user_info.login_methods + if lm.recipe_user_id.get_as_string() + == session_container.get_recipe_user_id(user_context).get_as_string() + ), + None, + ) + + if login_method is None: + # we pick the oldest login method here for the user. + # this can happen in case the user is implementing something like + # MFA where the session remains the same during the second factor as well. + login_method = min(user_info.login_methods, key=lambda lm: lm.time_joined) + + user_obj: Dict[str, Any] = { + "id": user_info.id, # we purposely use this instead of the loginmethod's recipeUserId because if the oldest login method is deleted, then this userID should remain the same. + "timeJoined": login_method.time_joined, + } + if login_method.third_party: + user_obj["thirdParty"] = login_method.third_party + if login_method.email: + user_obj["email"] = login_method.email + if login_method.phone_number: + user_obj["phoneNumber"] = login_method.phone_number + + resp = {"user": user_obj} + + if created_new_recipe_user is not None: + resp["createdNewUser"] = created_new_recipe_user + + return resp + + +def get_latest_fdi_version_from_fdi_list(fdi_header_value: str) -> str: + versions = fdi_header_value.split(",") + max_version_str = versions[0] + for version in versions[1:]: + max_version_str = _get_max_version(max_version_str, version) + return max_version_str + + +def has_greater_than_equal_to_fdi(req: BaseRequest, version: str) -> bool: + request_fdi = req.get_header(FDI_KEY_HEADER) + if request_fdi is None: + # By default we assume they want to use the latest FDI, this also helps with tests + return True + request_fdi = get_latest_fdi_version_from_fdi_list(request_fdi) + if request_fdi == version or _get_max_version(version, request_fdi) != version: + return True + return False class RWMutex: From 018cbc491ce9e3539f5e7566d4849ca24f948384 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 23 Jul 2024 00:21:58 +0530 Subject: [PATCH 004/126] adds process state file changes --- supertokens_python/process_state.py | 27 ++++++++++++++----- supertokens_python/querier.py | 6 ++--- .../recipe/session/session_functions.py | 6 ++--- tests/test_session.py | 18 ++++++------- 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/supertokens_python/process_state.py b/supertokens_python/process_state.py index e50fb95b2..d7decdd2e 100644 --- a/supertokens_python/process_state.py +++ b/supertokens_python/process_state.py @@ -12,22 +12,27 @@ # License for the specific language governing permissions and limitations # under the License. from os import environ -from typing import List +from typing import List, Optional from enum import Enum -class AllowedProcessStates(Enum): +class PROCESS_STATE(Enum): CALLING_SERVICE_IN_VERIFY = 1 - CALLING_SERVICE_IN_GET_HANDSHAKE_INFO = 2 - CALLING_SERVICE_IN_GET_API_VERSION = 3 - CALLING_SERVICE_IN_REQUEST_HELPER = 4 + CALLING_SERVICE_IN_GET_API_VERSION = 2 + CALLING_SERVICE_IN_REQUEST_HELPER = 3 + MULTI_JWKS_VALIDATION = 4 + IS_SIGN_IN_UP_ALLOWED_NO_PRIMARY_USER_EXISTS = 5 + IS_SIGN_UP_ALLOWED_CALLED = 6 + IS_SIGN_IN_ALLOWED_CALLED = 7 + IS_SIGN_IN_UP_ALLOWED_HELPER_CALLED = 8 + ADDING_NO_CACHE_HEADER_IN_FETCH = 9 class ProcessState: __instance = None def __init__(self): - self.history: List[AllowedProcessStates] = [] + self.history: List[PROCESS_STATE] = [] @staticmethod def get_instance(): @@ -35,9 +40,17 @@ def get_instance(): ProcessState.__instance = ProcessState() return ProcessState.__instance - def add_state(self, state: AllowedProcessStates): + def add_state(self, state: PROCESS_STATE): if ("SUPERTOKENS_ENV" in environ) and (environ["SUPERTOKENS_ENV"] == "testing"): self.history.append(state) def reset(self): self.history = [] + + def get_event_by_last_event_by_name( + self, state: PROCESS_STATE + ) -> Optional[PROCESS_STATE]: + for event in reversed(self.history): + if event == state: + return event + return None diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index c1de63dd2..4232dd8df 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -35,7 +35,7 @@ from typing import List, Set, Union -from .process_state import AllowedProcessStates, ProcessState +from .process_state import PROCESS_STATE, ProcessState from .utils import find_max_version, is_4xx_error, is_5xx_error from sniffio import AsyncLibraryNotFoundError from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop @@ -128,7 +128,7 @@ async def get_api_version(self): return Querier.api_version ProcessState.get_instance().add_state( - AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION + PROCESS_STATE.CALLING_SERVICE_IN_GET_API_VERSION ) async def f(url: str, method: str) -> Response: @@ -463,7 +463,7 @@ async def __send_request_helper( retry_info_map[url] = max_retries ProcessState.get_instance().add_state( - AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER + PROCESS_STATE.CALLING_SERVICE_IN_REQUEST_HELPER ) response = await http_function(url, method) if ("SUPERTOKENS_ENV" in environ) and ( diff --git a/supertokens_python/recipe/session/session_functions.py b/supertokens_python/recipe/session/session_functions.py index b2e214721..5037ad45f 100644 --- a/supertokens_python/recipe/session/session_functions.py +++ b/supertokens_python/recipe/session/session_functions.py @@ -26,7 +26,7 @@ from supertokens_python.logger import log_debug_message from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.process_state import AllowedProcessStates, ProcessState +from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.recipe.session.interfaces import TokenInfo from .exceptions import ( @@ -270,9 +270,7 @@ async def get_session( ) ) - ProcessState.get_instance().add_state( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY - ) + ProcessState.get_instance().add_state(PROCESS_STATE.CALLING_SERVICE_IN_VERIFY) data = { "accessToken": parsed_access_token.raw_token_string, diff --git a/tests/test_session.py b/tests/test_session.py index 38cc901da..f42446095 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -27,7 +27,7 @@ from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.framework import BaseRequest from supertokens_python.framework.fastapi.fastapi_middleware import get_middleware -from supertokens_python.process_state import AllowedProcessStates, ProcessState +from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.recipe import session from supertokens_python.recipe.session import InputOverrideConfig, SessionRecipe from supertokens_python.recipe.session.asyncio import ( @@ -119,7 +119,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): s.recipe_implementation, access_token, response.antiCsrfToken, True, False, None ) assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -151,8 +151,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): ) assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY - in ProcessState.get_instance().history + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY in ProcessState.get_instance().history ) assert response3.session is not None @@ -173,7 +172,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): None, ) assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -742,7 +741,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session1.refresh_token is not None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -756,7 +755,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session2.refresh_token is None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -770,7 +769,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session3.refresh_token is not None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -784,8 +783,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session4.refresh_token is None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY - in ProcessState.get_instance().history + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY in ProcessState.get_instance().history ) # Core got called this time From f88a855c7da9aad86d17c5d30bf0e973065ee4f9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 23 Jul 2024 01:46:54 +0530 Subject: [PATCH 005/126] removes get_users_oldest/newest first from supertokens.py --- supertokens_python/asyncio/__init__.py | 22 +- .../dashboard/api/userdetails/user_delete.py | 4 +- .../dashboard/api/userdetails/user_get.py | 7 +- .../dashboard/api/userdetails/user_put.py | 19 +- .../recipe/dashboard/api/users_get.py | 45 ++-- .../recipe/dashboard/interfaces.py | 4 +- supertokens_python/recipe/dashboard/utils.py | 216 +++++++----------- supertokens_python/supertokens.py | 84 +------ supertokens_python/syncio/__init__.py | 19 +- supertokens_python/types.py | 6 +- .../test_supertokens_functions.py | 14 +- tests/test_pagination.py | 6 +- 12 files changed, 169 insertions(+), 277 deletions(-) diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index 59796af43..b51b9f05f 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -35,15 +35,7 @@ async def get_users_oldest_first( query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, ) -> UsersResponse: - return await Supertokens.get_instance().get_users( - tenant_id, - "ASC", - limit, - pagination_token, - include_recipe_ids, - query, - user_context, - ) + raise NotImplementedError("This function is not implemented") async def get_users_newest_first( @@ -54,15 +46,7 @@ async def get_users_newest_first( query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, ) -> UsersResponse: - return await Supertokens.get_instance().get_users( - tenant_id, - "DESC", - limit, - pagination_token, - include_recipe_ids, - query, - user_context, - ) + raise NotImplementedError("This function is not implemented") async def get_user_count( @@ -78,7 +62,7 @@ async def get_user_count( async def delete_user( user_id: str, user_context: Optional[Dict[str, Any]] = None ) -> None: - return await Supertokens.get_instance().delete_user(user_id, user_context) + raise NotImplementedError("This function is not implemented") async def create_user_id_mapping( diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py b/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py index 44e59e427..0e781450b 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py @@ -2,7 +2,7 @@ from ...interfaces import APIInterface, APIOptions, UserDeleteAPIResponse from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python import Supertokens +from supertokens_python.asyncio import delete_user async def handle_user_delete( @@ -16,6 +16,6 @@ async def handle_user_delete( if user_id is None: raise_bad_input_exception("Missing required parameter 'userId'") - await Supertokens.get_instance().delete_user(user_id, _user_context) + await delete_user(user_id, _user_context) return UserDeleteAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_get.py index 0d8350b0b..f4559f48d 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_get.py @@ -4,6 +4,7 @@ from supertokens_python.recipe.dashboard.utils import get_user_for_recipe_id from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.usermetadata.asyncio import get_user_metadata +from supertokens_python.types import RecipeUserId from ...interfaces import ( APIInterface, @@ -40,8 +41,10 @@ async def handle_user_get( if not is_recipe_initialised(recipe_id): return UserGetAPIRecipeNotInitialisedError() - user_response = await get_user_for_recipe_id(user_id, recipe_id) - if user_response is None: + user_response = await get_user_for_recipe_id( + RecipeUserId(user_id), recipe_id, _user_context + ) + if user_response.user is None: return UserGetAPINoUserFoundError() user = user_response.user diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py index 25b4b6360..4f72e8cc7 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py @@ -35,6 +35,7 @@ ) from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.usermetadata.asyncio import update_user_metadata +from supertokens_python.types import RecipeUserId from ...interfaces import ( APIInterface, @@ -219,9 +220,11 @@ async def handle_user_put( "Required parameter 'phone' is missing or has an invalid type" ) - user_response = await get_user_for_recipe_id(user_id, recipe_id) + user_response = await get_user_for_recipe_id( + RecipeUserId(user_id), recipe_id, user_context + ) - if user_response is None: + if user_response.user is None: raise Exception("Should never come here") first_name = first_name.strip() @@ -251,7 +254,11 @@ async def handle_user_put( if email != "": email_update_response = await update_email_for_recipe_id( - user_response.recipe, user_id, email, tenant_id, user_context + user_response.recipe or "passwordless", + user_id, + email, + tenant_id, + user_context, ) if not isinstance(email_update_response, UserPutAPIOkResponse): @@ -259,7 +266,11 @@ async def handle_user_put( if phone != "": phone_update_response = await update_phone_for_recipe_id( - user_response.recipe, user_id, phone, tenant_id, user_context + user_response.recipe or "passwordless", + user_id, + phone, + tenant_id, + user_context, ) if not isinstance(phone_update_response, UserPutAPIOkResponse): diff --git a/supertokens_python/recipe/dashboard/api/users_get.py b/supertokens_python/recipe/dashboard/api/users_get.py index 949e569da..7a3ff8452 100644 --- a/supertokens_python/recipe/dashboard/api/users_get.py +++ b/supertokens_python/recipe/dashboard/api/users_get.py @@ -14,11 +14,9 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Any, Awaitable, List, Dict +from typing import TYPE_CHECKING, Any, Awaitable, List, Dict, Union from typing_extensions import Literal -from supertokens_python.supertokens import Supertokens - from ...usermetadata import UserMetadataRecipe from ...usermetadata.asyncio import get_user_metadata from ..interfaces import DashboardUsersGetResponse @@ -29,9 +27,15 @@ APIOptions, APIInterface, ) - from supertokens_python.types import APIResponse + from supertokens_python.types import APIResponse, AccountLinkingUser from supertokens_python.exceptions import GeneralError, raise_bad_input_exception +from supertokens_python.asyncio import get_users_newest_first, get_users_oldest_first + + +class MockUsersResponse: + users: List[AccountLinkingUser] = [] + next_pagination_token: Union[str, None] = None async def handle_users_get_api( @@ -50,22 +54,23 @@ async def handle_users_get_api( "timeJoinedOrder", "DESC" ) if time_joined_order not in ["ASC", "DESC"]: - raise_bad_input_exception("Invalid value recieved for 'timeJoinedOrder'") + raise_bad_input_exception("Invalid value received for 'timeJoinedOrder'") pagination_token = api_options.request.get_query_param("paginationToken") + query = get_search_params_from_url(api_options.request.get_original_url()) - users_response = await Supertokens.get_instance().get_users( + users_response = await ( + get_users_newest_first + if time_joined_order == "DESC" + else get_users_oldest_first + )( tenant_id, - time_joined_order=time_joined_order, limit=int(limit), pagination_token=pagination_token, - include_recipe_ids=None, - query=api_options.request.get_query_params(), + query=query, user_context=user_context, ) - # user metadata bulk fetch with batches: - try: UserMetadataRecipe.get_instance() except GeneralError: @@ -80,15 +85,13 @@ async def handle_users_get_api( async def get_user_metadata_and_update_user(user_idx: int) -> None: user = users_response.users[user_idx] - user_metadata = await get_user_metadata(user.user_id, user_context) + user_metadata = await get_user_metadata(user.id, user_context) first_name = user_metadata.metadata.get("first_name") last_name = user_metadata.metadata.get("last_name") - # None becomes null which is acceptable for the dashboard. users_with_metadata[user_idx].first_name = first_name users_with_metadata[user_idx].last_name = last_name - # Batch calls to get user metadata: for i, _ in enumerate(users_response.users): metadata_fetch_awaitables.append(get_user_metadata_and_update_user(i)) @@ -108,10 +111,22 @@ async def get_user_metadata_and_update_user(user_idx: int) -> None: ) ] await asyncio.gather(*promises_to_call) - promise_arr_start_position += batch_size return DashboardUsersGetResponse( users_with_metadata, users_response.next_pagination_token, ) + + +def get_search_params_from_url(url: str) -> Dict[str, str]: + from urllib.parse import urlparse, parse_qs + + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + search_query = { + key: value[0] + for key, value in query_params.items() + if key not in ["limit", "timeJoinedOrder", "paginationToken"] + } + return search_query diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index 20033301b..dbf3727ce 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union -from supertokens_python.types import User +from supertokens_python.types import AccountLinkingUser from ...types import APIResponse @@ -94,7 +94,7 @@ class DashboardUsersGetResponse(APIResponse): def __init__( self, - users: Union[List[User], List[UserWithMetadata]], + users: Union[List[AccountLinkingUser], List[UserWithMetadata]], next_pagination_token: Optional[str], ): self.users = users diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 209d93eb4..bdf3995f2 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -13,25 +13,16 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List, Literal if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest from supertokens_python.recipe.emailpassword import EmailPasswordRecipe -from supertokens_python.recipe.emailpassword.asyncio import ( - get_user_by_id as ep_get_user_by_id, -) from supertokens_python.recipe.passwordless import PasswordlessRecipe -from supertokens_python.recipe.passwordless.asyncio import ( - get_user_by_id as pless_get_user_by_id, -) from supertokens_python.recipe.thirdparty import ThirdPartyRecipe -from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_id as tp_get_user_by_idx, -) -from supertokens_python.types import User -from supertokens_python.utils import Awaitable, log_debug_message, normalise_email +from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.utils import log_debug_message, normalise_email from ...normalised_url_path import NormalisedURLPath from .constants import ( @@ -56,94 +47,25 @@ class UserWithMetadata: - user_id: str - time_joined: int - recipe_id: Optional[str] = None - email: Optional[str] = None - phone_number: Optional[str] = None - tp_info: Optional[Dict[str, Any]] = None + user: AccountLinkingUser first_name: Optional[str] = None last_name: Optional[str] = None - tenant_ids: List[str] def from_user( self, - user: User, + user: AccountLinkingUser, first_name: Optional[str] = None, last_name: Optional[str] = None, ): self.first_name = first_name self.last_name = last_name - - self.user_id = user.user_id - # from_user() is called in /api/users (note extra s) - # here we DashboardUsersGetResponse() doesn't maintain - # recipe id for each user on its own. That's why we need - # to set self.recipe_id here. - self.recipe_id = user.recipe_id - self.time_joined = user.time_joined - self.email = user.email - self.phone_number = user.phone_number - self.tp_info = ( - None if user.third_party_info is None else user.third_party_info.__dict__ - ) - self.tenant_ids = user.tenant_ids - - return self - - def from_dict( - self, - user_obj_dict: Dict[str, Any], - first_name: Optional[str] = None, - last_name: Optional[str] = None, - ): - self.first_name = first_name - self.last_name = last_name - - self.user_id = user_obj_dict["user_id"] - # from_dict() is used in `/api/user` where - # recipe_id is already passed seperately to - # GetUserForRecipeIdResult object - # So we set recipe_id to None here - self.recipe_id = None - self.time_joined = user_obj_dict["time_joined"] - self.tenant_ids = user_obj_dict.get("tenant_ids", []) - - self.email = user_obj_dict.get("email") - self.phone_number = user_obj_dict.get("phone_number") - self.tp_info = ( - None - if user_obj_dict.get("third_party_info") is None - else user_obj_dict["third_party_info"].__dict__ - ) - + self.user = user return self def to_json(self) -> Dict[str, Any]: - user_json: Dict[str, Any] = { - "id": self.user_id, - "timeJoined": self.time_joined, - "tenantIds": self.tenant_ids, - } - if self.tp_info is not None: - user_json["thirdParty"] = { - "id": self.tp_info["id"], - "userId": self.tp_info["user_id"], - } - if self.phone_number is not None: - user_json["phoneNumber"] = self.phone_number - if self.email is not None: - user_json["email"] = self.email - if self.first_name is not None: - user_json["firstName"] = self.first_name - if self.last_name is not None: - user_json["lastName"] = self.last_name - - if self.recipe_id is not None: - return { - "recipeId": self.recipe_id, - "user": user_json, - } + user_json = self.user.to_json() + user_json["firstName"] = self.first_name + user_json["lastName"] = self.last_name return user_json @@ -263,12 +185,6 @@ def is_valid_recipe_id(recipe_id: str) -> bool: return recipe_id in ("emailpassword", "thirdparty", "passwordless") -class GetUserForRecipeIdResult: - def __init__(self, user: UserWithMetadata, recipe: str): - self.user = user - self.recipe = recipe - - if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.types import User as EmailPasswordUser from supertokens_python.recipe.passwordless.types import User as PasswordlessUser @@ -282,53 +198,93 @@ def __init__(self, user: UserWithMetadata, recipe: str): ] -async def get_user_for_recipe_id( - user_id: str, recipe_id: str -) -> Optional[GetUserForRecipeIdResult]: - user: Optional[UserWithMetadata] = None - recipe: Optional[str] = None - - async def update_user_dict( - get_user_funcs: List[Callable[[str], Awaitable[GetUserResult]]], - recipes: List[str], +class GetUserForRecipeIdHelperResult: + def __init__( + self, user: Optional[AccountLinkingUser] = None, recipe: Optional[str] = None ): - nonlocal user, user_id, recipe + self.user = user + self.recipe = recipe - for get_user_func, recipe_id in zip(get_user_funcs, recipes): - try: - recipe_user = await get_user_func(user_id) # type: ignore - if recipe_user is not None: - user = UserWithMetadata().from_dict( - recipe_user.__dict__, first_name="", last_name="" - ) - recipe = recipe_id - break - except Exception: - pass +class GetUserForRecipeIdResult: + def __init__( + self, user: Optional[UserWithMetadata] = None, recipe: Optional[str] = None + ): + self.user = user + self.recipe = recipe - if recipe_id == EmailPasswordRecipe.recipe_id: - await update_user_dict( - [ep_get_user_by_id], - ["emailpassword"], - ) - elif recipe_id == ThirdPartyRecipe.recipe_id: - await update_user_dict( - [tp_get_user_by_idx], - ["thirdparty"], +async def get_user_for_recipe_id( + recipe_user_id: RecipeUserId, recipe_id: str, user_context: Dict[str, Any] +) -> GetUserForRecipeIdResult: + user_response = await _get_user_for_recipe_id( + recipe_user_id, recipe_id, user_context + ) + + user = None + if user_response.user is not None: + user = UserWithMetadata().from_user( + user_response.user, first_name="", last_name="" ) - elif recipe_id == PasswordlessRecipe.recipe_id: - await update_user_dict( - [pless_get_user_by_id], - ["passwordless"], + return GetUserForRecipeIdResult(user=user, recipe=user_response.recipe) + + +async def _get_user_for_recipe_id( + recipe_user_id: RecipeUserId, recipe_id: str, user_context: Dict[str, Any] +) -> GetUserForRecipeIdHelperResult: + recipe: Optional[Literal["emailpassword", "thirdparty", "passwordless"]] = None + + # Simple mock for get_user + async def mock_get_user(params: Dict[str, Any]) -> Optional[AccountLinkingUser]: + # This is a basic mock. You might want to expand this based on your needs. + raise NotImplementedError( + "This is a mock function. Implement this based on your needs." ) - if user is not None and recipe is not None: - return GetUserForRecipeIdResult(user, recipe) + user = await mock_get_user( + { + "user_id": recipe_user_id.get_as_string(), + "user_context": user_context, + } + ) - return None + if user is None: + return GetUserForRecipeIdHelperResult(user=None, recipe=None) + + login_method = next( + ( + m + for m in user.login_methods + if m.recipe_id == recipe_id + and m.recipe_user_id.get_as_string() == recipe_user_id.get_as_string() + ), + None, + ) + + if login_method is None: + return GetUserForRecipeIdHelperResult(user=None, recipe=None) + + if recipe_id == EmailPasswordRecipe.recipe_id: + try: + EmailPasswordRecipe.get_instance() + recipe = "emailpassword" + except Exception: + pass + elif recipe_id == ThirdPartyRecipe.recipe_id: + try: + ThirdPartyRecipe.get_instance() + recipe = "thirdparty" + except Exception: + pass + elif recipe_id == PasswordlessRecipe.recipe_id: + try: + PasswordlessRecipe.get_instance() + recipe = "passwordless" + except Exception: + pass + + return GetUserForRecipeIdHelperResult(user=user, recipe=recipe) def is_recipe_initialised(recipeId: str) -> bool: diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 119b820cc..d37772c85 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -26,7 +26,7 @@ ) -from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT, USER_DELETE, USERS +from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT from .exceptions import SuperTokensError from .interfaces import ( CreateUserIdMappingOkResult, @@ -42,7 +42,6 @@ from .normalised_url_path import NormalisedURLPath from .post_init_callbacks import PostSTInitCallbacks from .querier import Querier -from .types import ThirdPartyInfo, User, UsersResponse from .utils import ( get_rid_from_header, get_top_level_domain_for_same_site_resolution, @@ -351,87 +350,6 @@ async def get_user_count( # pylint: disable=no-self-use return int(response["count"]) - async def delete_user( # pylint: disable=no-self-use - self, - user_id: str, - user_context: Optional[Dict[str, Any]], - ) -> None: - querier = Querier.get_instance(None) - - cdi_version = await querier.get_api_version() - - if is_version_gte(cdi_version, "2.10"): - await querier.send_post_request( - NormalisedURLPath(USER_DELETE), - {"userId": user_id}, - user_context=user_context, - ) - - return None - raise_general_exception("Please upgrade the SuperTokens core to >= 3.7.0") - - async def get_users( # pylint: disable=no-self-use - self, - tenant_id: str, - time_joined_order: Literal["ASC", "DESC"], - limit: Union[int, None], - pagination_token: Union[str, None], - include_recipe_ids: Union[None, List[str]], - query: Union[Dict[str, str], None], - user_context: Optional[Dict[str, Any]], - ) -> UsersResponse: - - querier = Querier.get_instance(None) - params = {"timeJoinedOrder": time_joined_order} - if limit is not None: - params = {"limit": limit, **params} - if pagination_token is not None: - params = {"paginationToken": pagination_token, **params} - - include_recipe_ids_str = None - if include_recipe_ids is not None: - include_recipe_ids_str = ",".join(include_recipe_ids) - params = {"includeRecipeIds": include_recipe_ids_str, **params} - - if query is not None: - params = {**params, **query} - - response = await querier.send_get_request( - NormalisedURLPath(f"/{tenant_id}{USERS}"), params, user_context=user_context - ) - next_pagination_token = None - if "nextPaginationToken" in response: - next_pagination_token = response["nextPaginationToken"] - users_list = response["users"] - users: List[User] = [] - for user in users_list: - recipe_id = user["recipeId"] - user_obj = user["user"] - third_party = None - if "thirdParty" in user_obj: - third_party = ThirdPartyInfo( - user_obj["thirdParty"]["userId"], user_obj["thirdParty"]["id"] - ) - email = None - if "email" in user_obj: - email = user_obj["email"] - phone_number = None - if "phoneNumber" in user_obj: - phone_number = user_obj["phoneNumber"] - users.append( - User( - recipe_id, - user_obj["id"], - user_obj["timeJoined"], - email, - phone_number, - third_party, - user_obj["tenantIds"], - ) - ) - - return UsersResponse(users, next_pagination_token) - async def create_user_id_mapping( # pylint: disable=no-self-use self, supertokens_user_id: str, diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index b1557b074..911e5a6a1 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -25,7 +25,6 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) -from supertokens_python.types import UsersResponse def get_users_oldest_first( @@ -35,11 +34,12 @@ def get_users_oldest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: +): + from supertokens_python.asyncio import get_users_oldest_first + return sync( - Supertokens.get_instance().get_users( + get_users_oldest_first( tenant_id, - "ASC", limit, pagination_token, include_recipe_ids, @@ -56,11 +56,12 @@ def get_users_newest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: +): + from supertokens_python.asyncio import get_users_newest_first + return sync( - Supertokens.get_instance().get_users( + get_users_newest_first( tenant_id, - "DESC", limit, pagination_token, include_recipe_ids, @@ -83,7 +84,9 @@ def get_user_count( def delete_user(user_id: str, user_context: Optional[Dict[str, Any]] = None) -> None: - return sync(Supertokens.get_instance().delete_user(user_id, user_context)) + from supertokens_python.asyncio import delete_user + + return sync(delete_user(user_id, user_context)) def create_user_id_mapping( diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 27691329e..b6ce766d1 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -174,8 +174,10 @@ def to_json(self) -> Dict[str, Any]: class UsersResponse: - def __init__(self, users: List[User], next_pagination_token: Union[str, None]): - self.users: List[User] = users + def __init__( + self, users: List[AccountLinkingUser], next_pagination_token: Union[str, None] + ): + self.users: List[AccountLinkingUser] = users self.next_pagination_token: Union[str, None] = next_pagination_token diff --git a/tests/supertokens_python/test_supertokens_functions.py b/tests/supertokens_python/test_supertokens_functions.py index d33299a61..fdd8374bf 100644 --- a/tests/supertokens_python/test_supertokens_functions.py +++ b/tests/supertokens_python/test_supertokens_functions.py @@ -68,12 +68,12 @@ async def test_supertokens_functions(): # Get users in ascending order by joining time users_asc = (await st_asyncio.get_users_oldest_first("public", limit=10)).users - emails_asc = [user.email for user in users_asc] + emails_asc = [user.emails[0] for user in users_asc] assert emails_asc == emails # Get users in descending order by joining time users_desc = (await st_asyncio.get_users_newest_first("public", limit=10)).users - emails_desc = [user.email for user in users_desc] + emails_desc = [user.emails[0] for user in users_desc] assert emails_desc == emails[::-1] version = await Querier.get_instance().get_api_version() @@ -87,7 +87,7 @@ async def test_supertokens_functions(): # Again, get users in ascending order by joining time # We expect that the 2nd user (bar@example.com) must be absent. users_asc = (await st_asyncio.get_users_oldest_first("public", limit=10)).users - emails_asc = [user.email for user in users_asc] + emails_asc = [user.emails[0] for user in users_asc] assert emails[1] not in emails_asc # The 2nd user must be deleted now. if not is_version_gte(version, "2.20"): @@ -103,8 +103,8 @@ async def test_supertokens_functions(): "public", limit=10, query={"email": "baz"} ) ).users - emails_asc = [user.email for user in users_asc] - emails_desc = [user.email for user in users_desc] + emails_asc = [user.emails[0] for user in users_asc] + emails_desc = [user.emails[0] for user in users_desc] assert len(emails_asc) == 1 assert len(emails_desc) == 1 @@ -118,7 +118,7 @@ async def test_supertokens_functions(): "public", limit=10, query={"email": "john"} ) ).users - emails_asc = [user.email for user in users_asc] - emails_desc = [user.email for user in users_desc] + emails_asc = [user.emails[0] for user in users_asc] + emails_desc = [user.emails[0] for user in users_desc] assert len(emails_asc) == 0 assert len(emails_desc) == 0 diff --git a/tests/test_pagination.py b/tests/test_pagination.py index d66b9163b..fb1249e90 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -51,16 +51,16 @@ async def test_get_users_pagination(): # Get all the users (No limit) response = await get_users_newest_first("public") - assert [user.email for user in response.users] == [ + assert [user.emails[0] for user in response.users] == [ f"dummy{i}@gmail.com" for i in range(5) ][::-1] # Get only the oldest user response = await get_users_oldest_first("public", limit=1) - assert [user.email for user in response.users] == ["dummy0@gmail.com"] + assert [user.emails[0] for user in response.users] == ["dummy0@gmail.com"] # Test pagination response = await get_users_oldest_first( "public", limit=1, pagination_token=response.next_pagination_token ) - assert [user.email for user in response.users] == ["dummy1@gmail.com"] + assert [user.emails[0] for user in response.users] == ["dummy1@gmail.com"] From 83fd38265035550f22b20355a966a68b34e6b8f7 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 23 Jul 2024 12:09:51 +0530 Subject: [PATCH 006/126] adds auto init of usermetadata --- .../recipe/usermetadata/recipe.py | 6 ++-- .../recipe/usermetadata/utils.py | 2 +- supertokens_python/supertokens.py | 29 +++++++++++++++---- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/supertokens_python/recipe/usermetadata/recipe.py b/supertokens_python/recipe/usermetadata/recipe.py index 763390c80..6d45abfe2 100644 --- a/supertokens_python/recipe/usermetadata/recipe.py +++ b/supertokens_python/recipe/usermetadata/recipe.py @@ -15,7 +15,7 @@ from __future__ import annotations from os import environ -from typing import List, Union, Optional, Dict, Any +from typing import List, Union, Optional, Dict, Any, TYPE_CHECKING from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse @@ -31,7 +31,9 @@ validate_and_normalise_user_input, ) from supertokens_python.recipe_module import APIHandled, RecipeModule -from supertokens_python.supertokens import AppInfo + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo from .utils import InputOverrideConfig diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 66174196c..7e059cf74 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -20,10 +20,10 @@ APIInterface, RecipeInterface, ) -from supertokens_python.supertokens import AppInfo if TYPE_CHECKING: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe + from supertokens_python.supertokens import AppInfo class InputOverrideConfig: diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index d37772c85..27b978fd2 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -255,22 +255,41 @@ def __init__( "Please provide at least one recipe to the supertokens.init function call" ) - from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + # from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe + # from supertokens_python.recipe.totp.recipe import TOTPRecipe multitenancy_found = False + totp_found = False + user_metadata_found = False + multi_factor_auth_found = False def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: - nonlocal multitenancy_found + nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found recipe_module = recipe(self.app_info) - if recipe_module.get_recipe_id() == MultitenancyRecipe.recipe_id: + if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True + elif recipe_module.get_recipe_id() == "usermetadata": + user_metadata_found = True + # elif recipe_module.get_recipe_id() == MultiFactorAuthRecipe.recipe_id: + # multi_factor_auth_found = True + # elif recipe_module.get_recipe_id() == TOTPRecipe.recipe_id: + # totp_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) if not multitenancy_found: - recipe = MultitenancyRecipe.init()(self.app_info) - self.recipe_modules.append(recipe) + from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + + self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + if totp_found and not multi_factor_auth_found: + raise_general_exception( + "Please initialize the MultiFactorAuth recipe to use TOTP." + ) + if not user_metadata_found: + from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe + + self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) self.telemetry = ( telemetry From 0afaaceb0735cfc2dd7fcc5f3561d10aa6bbc884 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 25 Jul 2024 15:13:40 +0530 Subject: [PATCH 007/126] adds types for account linking recipe --- supertokens_python/asyncio/__init__.py | 43 ++- .../recipe/accountlinking/interfaces.py | 278 ++++++++++++++++++ .../recipe/accountlinking/recipe.py | 151 ++++++++++ .../recipe/accountlinking/types.py | 124 ++++++++ .../dashboard/api/userdetails/user_delete.py | 18 +- .../recipe/dashboard/api/users_get.py | 9 +- supertokens_python/recipe/dashboard/utils.py | 15 +- supertokens_python/syncio/__init__.py | 8 +- supertokens_python/types.py | 8 - 9 files changed, 616 insertions(+), 38 deletions(-) create mode 100644 supertokens_python/recipe/accountlinking/interfaces.py create mode 100644 supertokens_python/recipe/accountlinking/recipe.py create mode 100644 supertokens_python/recipe/accountlinking/types.py diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index b51b9f05f..a0251620b 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -24,7 +24,8 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) -from supertokens_python.types import UsersResponse +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult async def get_users_oldest_first( @@ -34,8 +35,18 @@ async def get_users_oldest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: - raise NotImplementedError("This function is not implemented") +) -> GetUsersResult: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.get_users( + tenant_id, + time_joined_order="DESC", + limit=limit, + pagination_token=pagination_token, + include_recipe_ids=include_recipe_ids, + query=query, + user_context=user_context, + ) async def get_users_newest_first( @@ -45,8 +56,18 @@ async def get_users_newest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: - raise NotImplementedError("This function is not implemented") +) -> GetUsersResult: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.get_users( + tenant_id, + time_joined_order="ASC", + limit=limit, + pagination_token=pagination_token, + include_recipe_ids=include_recipe_ids, + query=query, + user_context=user_context, + ) async def get_user_count( @@ -60,9 +81,17 @@ async def get_user_count( async def delete_user( - user_id: str, user_context: Optional[Dict[str, Any]] = None + user_id: str, + remove_all_linked_accounts: bool = True, + user_context: Optional[Dict[str, Any]] = None, ) -> None: - raise NotImplementedError("This function is not implemented") + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.delete_user( + user_id, + remove_all_linked_accounts=remove_all_linked_accounts, + user_context=user_context, + ) async def create_user_id_mapping( diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py new file mode 100644 index 000000000..bbb05ac7e --- /dev/null +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -0,0 +1,278 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional +from typing_extensions import Literal + +if TYPE_CHECKING: + from supertokens_python.types import ( + AccountLinkingUser, + RecipeUserId, + ThirdPartyInfo, + ) + +from typing import Optional + + +class AccountInfo: + def __init__( + self, + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + self.email = email + self.phone_number = phone_number + self.third_party = third_party + + +class RecipeInterface(ABC): + @abstractmethod + async def get_users( + self, + tenant_id: str, + time_joined_order: Literal["ASC", "DESC"], + limit: Optional[int], + pagination_token: Optional[str], + include_recipe_ids: Optional[List[str]], + query: Optional[Dict[str, str]], + user_context: Dict[str, Any], + ) -> GetUsersResult: + pass + + @abstractmethod + async def can_create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + pass + + @abstractmethod + async def create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + pass + + @abstractmethod + async def can_link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CanLinkAccountsOkResult, + CanLinkAccountsRecipeUserIdAlreadyLinkedError, + CanLinkAccountsAccountInfoAlreadyAssociatedError, + CanLinkAccountsInputUserNotPrimaryError, + ]: + pass + + @abstractmethod + async def link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsInputUserNotPrimaryError, + ]: + pass + + @abstractmethod + async def unlink_account( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> UnlinkAccountOkResult: + pass + + @abstractmethod + async def get_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> Optional[AccountLinkingUser]: + pass + + @abstractmethod + async def list_users_by_account_info( + self, + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool, + user_context: Dict[str, Any], + ) -> List[AccountLinkingUser]: + pass + + @abstractmethod + async def delete_user( + self, + user_id: str, + remove_all_linked_accounts: bool, + user_context: Dict[str, Any], + ) -> None: + pass + + +class GetUsersResult: + def __init__( + self, users: List[AccountLinkingUser], next_pagination_token: Optional[str] + ): + self.users = users + self.next_pagination_token = next_pagination_token + + +class CanCreatePrimaryUserOkResult: + def __init__(self, was_already_a_primary_user: bool): + self.status: Literal["OK"] = "OK" + self.was_already_a_primary_user = was_already_a_primary_user + + +class CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError: + def __init__(self, primary_user_id: str, description: str): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanCreatePrimaryUserAccountInfoAlreadyAssociatedError: + def __init__(self, primary_user_id: str, description: str): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CreatePrimaryUserOkResult: + def __init__(self, user: AccountLinkingUser, was_already_a_primary_user: bool): + self.status: Literal["OK"] = "OK" + self.user = user + self.was_already_a_primary_user = was_already_a_primary_user + + +class CreatePrimaryUserRecipeUserIdAlreadyLinkedError: + def __init__(self, primary_user_id: str, description: Optional[str] = None): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CreatePrimaryUserAccountInfoAlreadyAssociatedError: + def __init__(self, primary_user_id: str, description: Optional[str] = None): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanLinkAccountsOkResult: + def __init__(self, accounts_already_linked: bool): + self.status: Literal["OK"] = "OK" + self.accounts_already_linked = accounts_already_linked + + +class CanLinkAccountsRecipeUserIdAlreadyLinkedError: + def __init__( + self, primary_user_id: Optional[str] = None, description: Optional[str] = None + ): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanLinkAccountsAccountInfoAlreadyAssociatedError: + def __init__( + self, primary_user_id: Optional[str] = None, description: Optional[str] = None + ): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanLinkAccountsInputUserNotPrimaryError: + def __init__(self, description: Optional[str] = None): + self.status: Literal[ + "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ] = "INPUT_USER_IS_NOT_A_PRIMARY_USER" + self.description = description + + +class LinkAccountsOkResult: + def __init__(self, accounts_already_linked: bool, user: AccountLinkingUser): + self.status: Literal["OK"] = "OK" + self.accounts_already_linked = accounts_already_linked + self.user = user + + +class LinkAccountsRecipeUserIdAlreadyLinkedError: + def __init__( + self, + primary_user_id: Optional[str] = None, + user: Optional[AccountLinkingUser] = None, + description: Optional[str] = None, + ): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.user = user + self.description = description + + +class LinkAccountsAccountInfoAlreadyAssociatedError: + def __init__( + self, + primary_user_id: Optional[str] = None, + user: Optional[AccountLinkingUser] = None, + description: Optional[str] = None, + ): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.user = user + self.description = description + + +class LinkAccountsInputUserNotPrimaryError: + def __init__(self, description: Optional[str] = None): + self.status: Literal[ + "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ] = "INPUT_USER_IS_NOT_A_PRIMARY_USER" + self.description = description + + +class UnlinkAccountOkResult: + def __init__(self, was_recipe_user_deleted: bool, was_linked: bool): + self.status: Literal["OK"] = "OK" + self.was_recipe_user_deleted = was_recipe_user_deleted + self.was_linked = was_linked diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py new file mode 100644 index 000000000..c83ccd3ba --- /dev/null +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from os import environ +from typing import Any, Dict, List, Union, TYPE_CHECKING, Optional, Callable +from supertokens_python.supertokens import Supertokens + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe_module import APIHandled, RecipeModule + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception + +from .types import ( + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, + AccountInfoWithRecipeIdAndUserId, + InputOverrideConfig, +) + +from .interfaces import RecipeInterface + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + from supertokens_python.types import AccountLinkingUser + from supertokens_python.recipe.session import SessionContainer + from supertokens_python.framework import BaseRequest, BaseResponse + + +class AccountLinkingRecipe(RecipeModule): + recipe_id = "accountlinking" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + on_account_linked: Optional[ + Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], None] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[AccountLinkingUser], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink], + ] + ] = None, + override: Optional[InputOverrideConfig] = None, + ): + super().__init__(recipe_id, app_info) + self.recipe_implementation: RecipeInterface + raise Exception("TODO: to implement") + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + def get_apis_handled(self) -> List[APIHandled]: + return [] + + async def handle_api_request( + self, + request_id: str, + tenant_id: Optional[str], + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> Union[BaseResponse, None]: + raise Exception("Should never come here") + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + on_account_linked: Optional[ + Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], None] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[AccountLinkingUser], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink], + ] + ] = None, + override: Optional[InputOverrideConfig] = None, + ): + def func(app_info: AppInfo): + if AccountLinkingRecipe.__instance is None: + AccountLinkingRecipe.__instance = AccountLinkingRecipe( + AccountLinkingRecipe.recipe_id, + app_info, + on_account_linked, + should_do_automatic_account_linking, + override, + ) + return AccountLinkingRecipe.__instance + raise Exception( + None, + "Accountlinking recipe has already been initialised. Please check your code for bugs.", + ) + + return func + + @staticmethod + def get_instance() -> AccountLinkingRecipe: + if AccountLinkingRecipe.__instance is None: + AccountLinkingRecipe.init()(Supertokens.get_instance().app_info) + + assert AccountLinkingRecipe.__instance is not None + return AccountLinkingRecipe.__instance + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + AccountLinkingRecipe.__instance = None diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py new file mode 100644 index 000000000..fbc42977f --- /dev/null +++ b/supertokens_python/recipe/accountlinking/types.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +from typing import Callable, Dict, Any, Union, Optional, List, TYPE_CHECKING +from typing_extensions import Literal +from supertokens_python.recipe.accountlinking.interfaces import ( + AccountInfo, + RecipeInterface, +) + +if TYPE_CHECKING: + from supertokens_python.types import ( + RecipeUserId, + ThirdPartyInfo, + AccountLinkingUser, + ) + from supertokens_python.recipe.session import SessionContainer + + +class AccountInfoWithRecipeId(AccountInfo): + def __init__( + self, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + super().__init__(email, phone_number, third_party) + self.recipe_id = recipe_id + + +class RecipeLevelUser(AccountInfoWithRecipeId): + def __init__( + self, + tenant_ids: List[str], + time_joined: int, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + super().__init__(recipe_id, email, phone_number, third_party) + self.tenant_ids = tenant_ids + self.time_joined = time_joined + self.recipe_id = recipe_id + + +class AccountInfoWithRecipeIdAndUserId(RecipeLevelUser): + def __init__( + self, + recipe_user_id: Optional[RecipeUserId], + tenant_ids: List[str], + time_joined: int, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + super().__init__( + tenant_ids, time_joined, recipe_id, email, phone_number, third_party + ) + self.recipe_user_id = recipe_user_id + + +class ShouldNotAutomaticallyLink: + def __init__(self): + self.should_automatically_link = False + + +class ShouldAutomaticallyLink: + def __init__(self, should_require_verification: bool): + self.should_automatically_link = True + self.should_require_verification = should_require_verification + + +class OverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + ): + self.functions = functions + + +class InputOverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + ): + self.functions = functions + + +class AccountLinkingConfig: + def __init__( + self, + on_account_linked: Callable[ + [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], None + ], + should_do_automatic_account_linking: Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[AccountLinkingUser], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink], + ], + override: OverrideConfig, + ): + self.on_account_linked = on_account_linked + self.should_do_automatic_account_linking = should_do_automatic_account_linking + self.override = override diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py b/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py index 0e781450b..82868584c 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py @@ -12,10 +12,24 @@ async def handle_user_delete( _user_context: Dict[str, Any], ) -> UserDeleteAPIResponse: user_id = api_options.request.get_query_param("userId") + remove_all_linked_accounts_query_value = api_options.request.get_query_param( + "removeAllLinkedAccounts" + ) - if user_id is None: + if remove_all_linked_accounts_query_value is not None: + remove_all_linked_accounts_query_value = ( + remove_all_linked_accounts_query_value.strip().lower() + ) + + remove_all_linked_accounts = ( + True + if remove_all_linked_accounts_query_value is None + else remove_all_linked_accounts_query_value == "true" + ) + + if user_id is None or user_id == "": raise_bad_input_exception("Missing required parameter 'userId'") - await delete_user(user_id, _user_context) + await delete_user(user_id, remove_all_linked_accounts) return UserDeleteAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/users_get.py b/supertokens_python/recipe/dashboard/api/users_get.py index 7a3ff8452..6da1a961a 100644 --- a/supertokens_python/recipe/dashboard/api/users_get.py +++ b/supertokens_python/recipe/dashboard/api/users_get.py @@ -14,7 +14,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Any, Awaitable, List, Dict, Union +from typing import TYPE_CHECKING, Any, Awaitable, List, Dict from typing_extensions import Literal from ...usermetadata import UserMetadataRecipe @@ -27,17 +27,12 @@ APIOptions, APIInterface, ) - from supertokens_python.types import APIResponse, AccountLinkingUser + from supertokens_python.types import APIResponse from supertokens_python.exceptions import GeneralError, raise_bad_input_exception from supertokens_python.asyncio import get_users_newest_first, get_users_oldest_first -class MockUsersResponse: - users: List[AccountLinkingUser] = [] - next_pagination_token: Union[str, None] = None - - async def handle_users_get_api( api_implementation: APIInterface, tenant_id: str, diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index bdf3995f2..db2206ed0 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -14,6 +14,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List, Literal +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -235,18 +236,8 @@ async def _get_user_for_recipe_id( ) -> GetUserForRecipeIdHelperResult: recipe: Optional[Literal["emailpassword", "thirdparty", "passwordless"]] = None - # Simple mock for get_user - async def mock_get_user(params: Dict[str, Any]) -> Optional[AccountLinkingUser]: - # This is a basic mock. You might want to expand this based on your needs. - raise NotImplementedError( - "This is a mock function. Implement this based on your needs." - ) - - user = await mock_get_user( - { - "user_id": recipe_user_id.get_as_string(), - "user_context": user_context, - } + user = await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + recipe_user_id.get_as_string(), user_context ) if user is None: diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index 911e5a6a1..415539771 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -83,10 +83,14 @@ def get_user_count( ) -def delete_user(user_id: str, user_context: Optional[Dict[str, Any]] = None) -> None: +def delete_user( + user_id: str, + remove_all_linked_accounts: bool = True, + user_context: Optional[Dict[str, Any]] = None, +) -> None: from supertokens_python.asyncio import delete_user - return sync(delete_user(user_id, user_context)) + return sync(delete_user(user_id, remove_all_linked_accounts, user_context)) def create_user_id_mapping( diff --git a/supertokens_python/types.py b/supertokens_python/types.py index b6ce766d1..e54c2b589 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -173,14 +173,6 @@ def to_json(self) -> Dict[str, Any]: return res -class UsersResponse: - def __init__( - self, users: List[AccountLinkingUser], next_pagination_token: Union[str, None] - ): - self.users: List[AccountLinkingUser] = users - self.next_pagination_token: Union[str, None] = next_pagination_token - - class APIResponse(ABC): @abstractmethod def to_json(self) -> Dict[str, Any]: From 5b71f8798e0c2eb16206a23e772ebe3a2d2bf460 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 7 Aug 2024 16:53:18 +0530 Subject: [PATCH 008/126] adds normalising of input for account linking recipe --- .../recipe/accountlinking/__init__.py | 52 +++++++ .../recipe/accountlinking/interfaces.py | 2 - .../recipe/accountlinking/recipe.py | 32 +++- .../accountlinking/recipe_implementation.py | 145 ++++++++++++++++++ .../recipe/accountlinking/types.py | 6 +- .../recipe/accountlinking/utils.py | 95 ++++++++++++ 6 files changed, 319 insertions(+), 13 deletions(-) create mode 100644 supertokens_python/recipe/accountlinking/__init__.py create mode 100644 supertokens_python/recipe/accountlinking/recipe_implementation.py create mode 100644 supertokens_python/recipe/accountlinking/utils.py diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py new file mode 100644 index 000000000..71ffa8119 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Callable, Union, Optional, Dict, Any, Awaitable + +from . import types + +from . import utils +from .recipe import AccountLinkingRecipe + +InputOverrideConfig = utils.InputOverrideConfig +AccountLinkingUser = types.AccountLinkingUser +RecipeLevelUser = types.RecipeLevelUser +AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId +SessionContainer = types.SessionContainer +ShouldAutomaticallyLink = types.ShouldAutomaticallyLink +ShouldNotAutomaticallyLink = types.ShouldNotAutomaticallyLink + + +def init( + on_account_linked: Optional[ + Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[AccountLinkingUser], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None, + override: Optional[InputOverrideConfig] = None, +): + return AccountLinkingRecipe.init( + on_account_linked, should_do_automatic_account_linking, override + ) diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index bbb05ac7e..867aeec93 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -24,8 +24,6 @@ ThirdPartyInfo, ) -from typing import Optional - class AccountInfo: def __init__( diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index c83ccd3ba..0aa8e73f0 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -14,13 +14,15 @@ from __future__ import annotations from os import environ -from typing import Any, Dict, List, Union, TYPE_CHECKING, Optional, Callable +from typing import Any, Dict, List, Union, TYPE_CHECKING, Optional, Callable, Awaitable from supertokens_python.supertokens import Supertokens from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe_module import APIHandled, RecipeModule - +from .utils import validate_and_normalise_user_input from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from .recipe_implementation import RecipeImplementation +from supertokens_python.querier import Querier from .types import ( RecipeLevelUser, @@ -48,7 +50,9 @@ def __init__( recipe_id: str, app_info: AppInfo, on_account_linked: Optional[ - Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], None] + Callable[ + [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None] + ] ] = None, should_do_automatic_account_linking: Optional[ Callable[ @@ -59,14 +63,24 @@ def __init__( str, Dict[str, Any], ], - Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None, override: Optional[InputOverrideConfig] = None, ): super().__init__(recipe_id, app_info) - self.recipe_implementation: RecipeInterface - raise Exception("TODO: to implement") + self.config = validate_and_normalise_user_input( + app_info, on_account_linked, should_do_automatic_account_linking, override + ) + recipe_implementation: RecipeInterface = RecipeImplementation( + Querier.get_instance(recipe_id) + ) + + self.recipe_implementation = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return False @@ -101,7 +115,9 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( on_account_linked: Optional[ - Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], None] + Callable[ + [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None] + ] ] = None, should_do_automatic_account_linking: Optional[ Callable[ @@ -112,7 +128,7 @@ def init( str, Dict[str, Any], ], - Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ] ] = None, override: Optional[InputOverrideConfig] = None, diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py new file mode 100644 index 000000000..e01c1d7ef --- /dev/null +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional +from typing_extensions import Literal + + +from .interfaces import ( + RecipeInterface, + GetUsersResult, + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserAccountInfoAlreadyAssociatedError, + CanLinkAccountsOkResult, + CanLinkAccountsRecipeUserIdAlreadyLinkedError, + CanLinkAccountsAccountInfoAlreadyAssociatedError, + CanLinkAccountsInputUserNotPrimaryError, + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsInputUserNotPrimaryError, + UnlinkAccountOkResult, + AccountLinkingUser, + RecipeUserId, + AccountInfo, +) + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + ): + super().__init__() + self.querier = querier + + async def get_users( + self, + tenant_id: str, + time_joined_order: Literal["ASC", "DESC"], + limit: Optional[int], + pagination_token: Optional[str], + include_recipe_ids: Optional[List[str]], + query: Optional[Dict[str, str]], + user_context: Dict[str, Any], + ) -> GetUsersResult: + # Implementation for get_users + raise NotImplementedError("get_users") + + async def can_create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + # Implementation for can_create_primary_user + raise NotImplementedError("can_create_primary_user") + + async def create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + # Implementation for create_primary_user + raise NotImplementedError("create_primary_user") + + async def can_link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CanLinkAccountsOkResult, + CanLinkAccountsRecipeUserIdAlreadyLinkedError, + CanLinkAccountsAccountInfoAlreadyAssociatedError, + CanLinkAccountsInputUserNotPrimaryError, + ]: + # Implementation for can_link_accounts + raise NotImplementedError("can_link_accounts") + + async def link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsInputUserNotPrimaryError, + ]: + # Implementation for link_accounts + raise NotImplementedError("link_accounts") + + async def unlink_account( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> UnlinkAccountOkResult: + # Implementation for unlink_account + raise NotImplementedError("unlink_account") + + async def get_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> Optional[AccountLinkingUser]: + # Implementation for get_user + raise NotImplementedError("get_user") + + async def list_users_by_account_info( + self, + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool, + user_context: Dict[str, Any], + ) -> List[AccountLinkingUser]: + # Implementation for list_users_by_account_info + raise NotImplementedError("list_users_by_account_info") + + async def delete_user( + self, + user_id: str, + remove_all_linked_accounts: bool, + user_context: Dict[str, Any], + ) -> None: + # Implementation for delete_user + raise NotImplementedError("delete_user") diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index fbc42977f..49a6384db 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import Callable, Dict, Any, Union, Optional, List, TYPE_CHECKING +from typing import Callable, Dict, Any, Union, Optional, List, TYPE_CHECKING, Awaitable from typing_extensions import Literal from supertokens_python.recipe.accountlinking.interfaces import ( AccountInfo, @@ -105,7 +105,7 @@ class AccountLinkingConfig: def __init__( self, on_account_linked: Callable[ - [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], None + [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None] ], should_do_automatic_account_linking: Callable[ [ @@ -115,7 +115,7 @@ def __init__( str, Dict[str, Any], ], - Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], ], override: OverrideConfig, ): diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py new file mode 100644 index 000000000..7e39bfc40 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Awaitable + +from .types import ( + AccountLinkingConfig, + AccountLinkingUser, + RecipeLevelUser, + AccountInfoWithRecipeIdAndUserId, + SessionContainer, + ShouldNotAutomaticallyLink, + ShouldAutomaticallyLink, + InputOverrideConfig, + OverrideConfig, +) + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + +async def default_on_account_linked( + _: AccountLinkingUser, __: RecipeLevelUser, ___: Dict[str, Any] +): + pass + + +_did_use_default_should_do_automatic_account_linking: bool = True + + +async def default_should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + ___: Optional[AccountLinkingUser], + ____: Optional[SessionContainer], + _____: str, + ______: Dict[str, Any], +) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: + return ShouldNotAutomaticallyLink() + + +def recipe_init_defined_should_do_automatic_account_linking() -> bool: + return _did_use_default_should_do_automatic_account_linking + + +def validate_and_normalise_user_input( + _: AppInfo, + on_account_linked: Optional[ + Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[AccountLinkingUser], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None, + override: Union[InputOverrideConfig, None] = None, +) -> AccountLinkingConfig: + global _did_use_default_should_do_automatic_account_linking + if override is None: + override = InputOverrideConfig() + + _did_use_default_should_do_automatic_account_linking = ( + should_do_automatic_account_linking is None + ) + + return AccountLinkingConfig( + override=OverrideConfig(functions=override.functions), + on_account_linked=( + default_on_account_linked + if on_account_linked is None + else on_account_linked + ), + should_do_automatic_account_linking=( + default_should_do_automatic_account_linking + if should_do_automatic_account_linking is None + else should_do_automatic_account_linking + ), + ) From 2db2b944ce883b0245f586b0b4fbdd55dcf54c86 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 7 Aug 2024 18:25:52 +0530 Subject: [PATCH 009/126] adds more functions to accountlinking recipe --- .pylintrc | 6 +- .../recipe/accountlinking/interfaces.py | 14 +- .../recipe/accountlinking/recipe.py | 563 +++++++++++++++++- .../recipe/accountlinking/types.py | 36 +- .../api/userdetails/user_email_verify_get.py | 2 +- .../emailverification/asyncio/__init__.py | 20 +- .../recipe/emailverification/recipe.py | 10 +- supertokens_python/types.py | 34 +- 8 files changed, 637 insertions(+), 48 deletions(-) diff --git a/.pylintrc b/.pylintrc index 25629e81f..bd9c207ce 100644 --- a/.pylintrc +++ b/.pylintrc @@ -116,7 +116,11 @@ disable=raw-checker-failed, global-statement, too-many-lines, duplicate-code, - too-many-return-statements + too-many-return-statements, + logging-not-lazy, + logging-fstring-interpolation, + consider-using-f-string + # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index 867aeec93..c5ccd1f4b 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -21,22 +21,10 @@ from supertokens_python.types import ( AccountLinkingUser, RecipeUserId, - ThirdPartyInfo, + AccountInfo, ) -class AccountInfo: - def __init__( - self, - email: Optional[str] = None, - phone_number: Optional[str] = None, - third_party: Optional[ThirdPartyInfo] = None, - ): - self.email = email - self.phone_number = phone_number - self.third_party = third_party - - class RecipeInterface(ABC): @abstractmethod async def get_users( diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 0aa8e73f0..b6ddc021a 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -23,6 +23,12 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from .recipe_implementation import RecipeImplementation from supertokens_python.querier import Querier +from supertokens_python.logger import ( + log_debug_message, +) +from supertokens_python.process_state import PROCESS_STATE, ProcessState +from typing_extensions import Literal +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe from .types import ( RecipeLevelUser, @@ -30,17 +36,33 @@ ShouldNotAutomaticallyLink, AccountInfoWithRecipeIdAndUserId, InputOverrideConfig, + AccountInfoWithRecipeId, + AccountInfo, ) from .interfaces import RecipeInterface +from supertokens_python.recipe.emailverification.interfaces import ( + CreateEmailVerificationTokenOkResult, +) + if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo - from supertokens_python.types import AccountLinkingUser + from supertokens_python.types import AccountLinkingUser, LoginMethod, RecipeUserId from supertokens_python.recipe.session import SessionContainer from supertokens_python.framework import BaseRequest, BaseResponse +class EmailChangeAllowedResult: + def __init__( + self, + allowed: bool, + reason: Literal["OK", "PRIMARY_USER_CONFLICT", "ACCOUNT_TAKEOVER_RISK"], + ): + self.allowed = allowed + self.reason = reason + + class AccountLinkingRecipe(RecipeModule): recipe_id = "accountlinking" __instance = None @@ -165,3 +187,542 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") AccountLinkingRecipe.__instance = None + + async def get_primary_user_that_can_be_linked_to_recipe_user_id( + self, + tenant_id: str, + user: AccountLinkingUser, + user_context: Dict[str, Any], + ) -> Optional[AccountLinkingUser]: + # First we check if this user itself is a primary user or not. If it is, we return that. + if user.is_primary_user: + return user + + # Then, we try and find a primary user based on the email / phone number / third party ID. + users = await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=user.login_methods[0], + do_union_of_account_info=True, + user_context=user_context, + ) + + log_debug_message( + "getPrimaryUserThatCanBeLinkedToRecipeUserId found %d matching users" + % len(users) + ) + primary_users = [u for u in users if u.is_primary_user] + log_debug_message( + "getPrimaryUserThatCanBeLinkedToRecipeUserId found %d matching primary users" + % len(primary_users) + ) + + if len(primary_users) > 1: + # This means that the new user has account info such that it's + # spread across multiple primary user IDs. In this case, even + # if we return one of them, it won't be able to be linked anyway + # cause if we did, it would mean 2 primary users would have the + # same account info. So we return None + + # This being said, with the current set of auth recipes, it should + # never come here - cause: + # ----> If the recipeuserid is a passwordless user, then it can have either a phone + # email or both. If it has just one of them, then anyway 2 primary users can't + # exist with the same phone number / email. If it has both, then the only way + # that it can have multiple primary users returned is if there is another passwordless + # primary user with the same phone number - which is not possible, cause phone + # numbers are unique across passwordless users. + # + # ----> If the input is a third party user, then it has third party info and an email. Now there can be able to primary user with the same email, but + # there can't be another thirdparty user with the same third party info (since that is unique). + # Nor can there an email password primary user with the same email along with another + # thirdparty primary user with the same email (since emails can't be the same across primary users). + # + # ----> If the input is an email password user, then it has an email. There can't be multiple primary users with the same email anyway. + raise Exception( + "You found a bug. Please report it on github.com/supertokens/supertokens-node" + ) + + return primary_users[0] if len(primary_users) > 0 else None + + async def get_oldest_user_that_can_be_linked_to_recipe_user( + self, + tenant_id: str, + user: AccountLinkingUser, + user_context: Dict[str, Any], + ) -> Optional[AccountLinkingUser]: + # First we check if this user itself is a primary user or not. If it is, we return that since it cannot be linked to anything else + if user.is_primary_user: + return user + + # Then, we try and find matching users based on the email / phone number / third party ID. + users = await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=user.login_methods[0], + do_union_of_account_info=True, + user_context=user_context, + ) + + log_debug_message( + f"getOldestUserThatCanBeLinkedToRecipeUser found {len(users)} matching users" + ) + + # Finally select the oldest one + oldest_user = min(users, key=lambda u: u.time_joined) if users else None + return oldest_user + + async def is_sign_in_allowed( + self, + user: AccountLinkingUser, + account_info: Union[AccountInfoWithRecipeId, LoginMethod], + tenant_id: str, + session: Optional[SessionContainer], + sign_in_verifies_login_method: bool, + user_context: Dict[str, Any], + ) -> bool: + ProcessState.get_instance().add_state(PROCESS_STATE.IS_SIGN_IN_ALLOWED_CALLED) + if ( + user.is_primary_user + or user.login_methods[0].verified + or sign_in_verifies_login_method + ): + return True + + return await self.is_sign_in_up_allowed_helper( + account_info=account_info, + is_verified=user.login_methods[0].verified, + session=session, + tenant_id=tenant_id, + is_sign_in=True, + user=user, + user_context=user_context, + ) + + async def is_sign_up_allowed( + self, + new_user: AccountInfoWithRecipeId, + is_verified: bool, + session: Optional[SessionContainer], + tenant_id: str, + user_context: Dict[str, Any], + ) -> bool: + ProcessState.get_instance().add_state(PROCESS_STATE.IS_SIGN_UP_ALLOWED_CALLED) + if new_user.email is not None and new_user.phone_number is not None: + # We do this check cause below when we call list_users_by_account_info, + # we only pass in one of email or phone number + raise Exception("Please pass one of email or phone number, not both") + + return await self.is_sign_in_up_allowed_helper( + account_info=new_user, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + user_context=user_context, + user=None, + is_sign_in=False, + ) + + async def is_sign_in_up_allowed_helper( + self, + account_info: Union[AccountInfoWithRecipeId, LoginMethod], + is_verified: bool, + session: Optional[SessionContainer], + tenant_id: str, + is_sign_in: bool, + user: Optional[AccountLinkingUser], + user_context: Dict[str, Any], + ) -> bool: + ProcessState.get_instance().add_state( + PROCESS_STATE.IS_SIGN_IN_UP_ALLOWED_HELPER_CALLED + ) + + users = await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=account_info, + do_union_of_account_info=True, + user_context=user_context, + ) + + if not users: + log_debug_message( + "isSignInUpAllowedHelper returning true because no user with given account info" + ) + return True + + if is_sign_in and user is None: + raise Exception( + "This should never happen: isSignInUpAllowedHelper called with isSignIn: true, user: None" + ) + + if ( + len(users) == 1 + and is_sign_in + and user is not None + and users[0].id == user.id + ): + log_debug_message( + "isSignInUpAllowedHelper returning true because this is sign in and there is only a single user with the given account info" + ) + return True + + primary_users = [u for u in users if u.is_primary_user] + + # pylint:disable=no-else-return + if not primary_users: + log_debug_message("isSignInUpAllowedHelper no primary user exists") + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + account_info + ), + None, + session, + tenant_id, + user_context, + ) + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + "isSignInUpAllowedHelper returning true because account linking is disabled" + ) + return True + + if not should_do_account_linking.should_require_verification: + log_debug_message( + "isSignInUpAllowedHelper returning true because dev does not require email verification" + ) + return True + + should_allow = True + for curr_user in users: + if session is not None and curr_user.id == session.get_user_id( + user_context + ): + # We do not consider the current session user to be conflicting + # This can be useful in cases where the current sign in will mark the session user as verified + continue + + this_iteration_is_verified = False + if account_info.email is not None: + if ( + curr_user.login_methods[0].has_same_email_as(account_info.email) + and curr_user.login_methods[0].verified + ): + log_debug_message( + "isSignInUpAllowedHelper found same email for another user and verified" + ) + this_iteration_is_verified = True + + if account_info.phone_number is not None: + if ( + curr_user.login_methods[0].has_same_phone_number_as( + account_info.phone_number + ) + and curr_user.login_methods[0].verified + ): + log_debug_message( + "isSignInUpAllowedHelper found same phone number for another user and verified" + ) + this_iteration_is_verified = True + + if not this_iteration_is_verified: + # even if one of the users is not verified, we do not allow sign up (see why above). + # Sure, this allows attackers to create email password accounts with an email + # to block actual users from signing up, but that's ok, since those + # users will just see an email already exists error and then will try another + # login method. They can also still just go through the password reset flow + # and then gain access to their email password account (which can then be verified). + log_debug_message( + "isSignInUpAllowedHelper returning false cause one of the other recipe level users is not verified" + ) + should_allow = False + break + + ProcessState.get_instance().add_state( + PROCESS_STATE.IS_SIGN_IN_UP_ALLOWED_NO_PRIMARY_USER_EXISTS + ) + log_debug_message(f"isSignInUpAllowedHelper returning {should_allow}") + return should_allow + else: + if len(primary_users) > 1: + raise Exception( + "You have found a bug. Please report to https://github.com/supertokens/supertokens-node/issues" + ) + + primary_user = primary_users[0] + log_debug_message("isSignInUpAllowedHelper primary user found") + + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + account_info + ), + primary_user, + session, + tenant_id, + user_context, + ) + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + "isSignInUpAllowedHelper returning true because account linking is disabled" + ) + return True + + if not should_do_account_linking.should_require_verification: + log_debug_message( + "isSignInUpAllowedHelper returning true because dev does not require email verification" + ) + return True + + if not is_verified: + log_debug_message( + "isSignInUpAllowedHelper returning false because new user's email is not verified, and primary user with the same email was found." + ) + return False + + if session is not None and primary_user.id == session.get_user_id( + user_context + ): + return True + + for login_method in primary_user.login_methods: + if login_method.email is not None: + if ( + login_method.has_same_email_as(account_info.email) + and login_method.verified + ): + log_debug_message( + "isSignInUpAllowedHelper returning true cause found same email for primary user and verified" + ) + return True + + if login_method.phone_number is not None: + if ( + login_method.has_same_phone_number_as(account_info.phone_number) + and login_method.verified + ): + log_debug_message( + "isSignInUpAllowedHelper returning true cause found same phone number for primary user and verified" + ) + return True + + log_debug_message( + "isSignInUpAllowedHelper returning false cause primary user does not have the same email or phone number that is verified" + ) + return False + + async def is_email_change_allowed( + self, + user: AccountLinkingUser, + new_email: str, + is_verified: bool, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> EmailChangeAllowedResult: + """ + The purpose of this function is to check if a recipe user ID's email + can be changed or not. There are two conditions for when it can't be changed: + - If the recipe user is a primary user, then we need to check that the new email + doesn't belong to any other primary user. If it does, we disallow the change + since multiple primary user's can't have the same account info. + + - If the recipe user is NOT a primary user, and if is_verified is false, then + we check if there exists a primary user with the same email, and if it does + we disallow the email change cause if this email is changed, and an email + verification email is sent, then the primary user may end up clicking + on the link by mistake, causing account linking to happen which can result + in account take over if this recipe user is malicious. + """ + + for tenant_id in user.tenant_ids: + existing_users_with_new_email = ( + await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=new_email), + do_union_of_account_info=False, + user_context=user_context, + ) + ) + + other_users_with_new_email = [ + u for u in existing_users_with_new_email if u.id != user.id + ] + other_primary_user_for_new_email = [ + u for u in other_users_with_new_email if u.is_primary_user + ] + + if len(other_primary_user_for_new_email) > 1: + raise Exception( + "You found a bug. Please report it on github.com/supertokens/supertokens-core" + ) + + # pylint:disable=no-else-return + if user.is_primary_user: + if other_primary_user_for_new_email: + log_debug_message( + f"isEmailChangeAllowed: returning false cause email change will lead to two primary users having same email on {tenant_id}" + ) + return EmailChangeAllowedResult( + allowed=False, reason="PRIMARY_USER_CONFLICT" + ) + + if is_verified: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is primary, new email is verified and doesn't belong to any other primary user" + ) + continue + + if any( + lm.has_same_email_as(new_email) and lm.verified + for lm in user.login_methods + ): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is primary, new email is verified in another login method and doesn't belong to any other primary user" + ) + continue + + if not other_users_with_new_email: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is primary and the new email doesn't belong to any other user (primary or non-primary)" + ) + continue + + should_do_account_linking = await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + other_users_with_new_email[0].login_methods[0] + ), + user, + session, + tenant_id, + user_context, + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause linking is disabled" + ) + continue + + if not should_do_account_linking.should_require_verification: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause linking doesn't require email verification" + ) + continue + + log_debug_message( + f"isEmailChangeAllowed: returning false because the user hasn't verified the new email address and there exists another user with it on {tenant_id} and linking requires verification" + ) + return EmailChangeAllowedResult( + allowed=False, reason="ACCOUNT_TAKEOVER_RISK" + ) + else: + if is_verified: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary and new email is verified" + ) + continue + + if user.login_methods[0].has_same_email_as(new_email): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary and new email is same as the older one" + ) + continue + + if other_primary_user_for_new_email: + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId( + recipe_id=user.login_methods[0].recipe_id, + email=user.login_methods[0].email, + recipe_user_id=user.login_methods[0].recipe_user_id, + phone_number=user.login_methods[0].phone_number, + third_party=user.login_methods[0].third_party, + ), + other_primary_user_for_new_email[0], + session, + tenant_id, + user_context, + ) + ) + + if isinstance( + should_do_account_linking, ShouldNotAutomaticallyLink + ): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary there exists a primary user exists with the new email, but the dev does not have account linking enabled." + ) + continue + + if not should_do_account_linking.should_require_verification: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary there exists a primary user exists with the new email, but the dev does not require email verification." + ) + continue + + log_debug_message( + "isEmailChangeAllowed: returning false cause input user is not a primary there exists a primary user exists with the new email." + ) + return EmailChangeAllowedResult( + allowed=False, reason="ACCOUNT_TAKEOVER_RISK" + ) + + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary no primary user exists with the new email" + ) + continue + + log_debug_message( + "isEmailChangeAllowed: returning true cause email change can happen on all tenants the user is part of" + ) + return EmailChangeAllowedResult(allowed=True, reason="OK") + + # pylint:disable=no-self-use + async def verify_email_for_recipe_user_if_linked_accounts_are_verified( + self, + user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], + ) -> None: + try: + EmailVerificationRecipe.get_instance_or_throw() + except Exception: + # if email verification recipe is not initialized, we do a no-op + return + + if user.is_primary_user: + recipe_user_email: Optional[str] = None + is_already_verified = False + for lm in user.login_methods: + if lm.recipe_user_id.get_as_string() == recipe_user_id.get_as_string(): + recipe_user_email = lm.email + is_already_verified = lm.verified + break + + if recipe_user_email is not None: + if is_already_verified: + return + should_verify_email = False + for lm in user.login_methods: + if lm.has_same_email_as(recipe_user_email) and lm.verified: + should_verify_email = True + break + + if should_verify_email: + ev_recipe = EmailVerificationRecipe.get_instance_or_throw() + resp = await ev_recipe.recipe_implementation.create_email_verification_token( + tenant_id=user.tenant_ids[0], + user_id=recipe_user_id.get_as_string(), + email=recipe_user_email, + user_context=user_context, + ) + if isinstance(resp, CreateEmailVerificationTokenOkResult): + # we purposely pass in false below cause we don't want account + # linking to happen + await ev_recipe.recipe_implementation.verify_email_using_token( + tenant_id=user.tenant_ids[0], + token=resp.token, + # attempt_account_linking=False, + user_context=user_context, + ) diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 49a6384db..e8520a605 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -25,6 +25,7 @@ RecipeUserId, ThirdPartyInfo, AccountLinkingUser, + LoginMethod, ) from supertokens_python.recipe.session import SessionContainer @@ -38,7 +39,9 @@ def __init__( third_party: Optional[ThirdPartyInfo] = None, ): super().__init__(email, phone_number, third_party) - self.recipe_id = recipe_id + self.recipe_id: Literal[ + "emailpassword", "thirdparty", "passwordless" + ] = recipe_id class RecipeLevelUser(AccountInfoWithRecipeId): @@ -54,34 +57,47 @@ def __init__( super().__init__(recipe_id, email, phone_number, third_party) self.tenant_ids = tenant_ids self.time_joined = time_joined - self.recipe_id = recipe_id + self.recipe_id: Literal[ + "emailpassword", "thirdparty", "passwordless" + ] = recipe_id -class AccountInfoWithRecipeIdAndUserId(RecipeLevelUser): +class AccountInfoWithRecipeIdAndUserId(AccountInfoWithRecipeId): def __init__( self, recipe_user_id: Optional[RecipeUserId], - tenant_ids: List[str], - time_joined: int, recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], email: Optional[str] = None, phone_number: Optional[str] = None, third_party: Optional[ThirdPartyInfo] = None, ): - super().__init__( - tenant_ids, time_joined, recipe_id, email, phone_number, third_party - ) + super().__init__(recipe_id, email, phone_number, third_party) self.recipe_user_id = recipe_user_id + @staticmethod + def from_account_info_or_login_method( + account_info: Union[AccountInfoWithRecipeId, LoginMethod], + ) -> AccountInfoWithRecipeIdAndUserId: + return AccountInfoWithRecipeIdAndUserId( + recipe_id=account_info.recipe_id, + email=account_info.email, + phone_number=account_info.phone_number, + third_party=account_info.third_party, + recipe_user_id=( + account_info.recipe_user_id + if isinstance(account_info, LoginMethod) + else None + ), + ) + class ShouldNotAutomaticallyLink: def __init__(self): - self.should_automatically_link = False + pass class ShouldAutomaticallyLink: def __init__(self, should_require_verification: bool): - self.should_automatically_link = True self.should_require_verification = should_require_verification diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py index 54e0c2a9e..f0652bd94 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py @@ -24,7 +24,7 @@ async def handle_user_email_verify_get( raise_bad_input_exception("Missing required parameter 'userId'") try: - EmailVerificationRecipe.get_instance() + EmailVerificationRecipe.get_instance_or_throw() except Exception: return FeatureNotEnabledError() diff --git a/supertokens_python/recipe/emailverification/asyncio/__init__.py b/supertokens_python/recipe/emailverification/asyncio/__init__.py index c15dc121d..b0ead0877 100644 --- a/supertokens_python/recipe/emailverification/asyncio/__init__.py +++ b/supertokens_python/recipe/emailverification/asyncio/__init__.py @@ -47,7 +47,7 @@ async def create_email_verification_token( ]: if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: email_info = await recipe.get_email_for_user_id(user_id, user_context) if isinstance(email_info, GetEmailForUserIdOkResult): @@ -67,7 +67,7 @@ async def verify_email_using_token( ): if user_context is None: user_context = {} - return await EmailVerificationRecipe.get_instance().recipe_implementation.verify_email_using_token( + return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.verify_email_using_token( token, tenant_id, user_context ) @@ -80,7 +80,7 @@ async def is_email_verified( if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: email_info = await recipe.get_email_for_user_id(user_id, user_context) if isinstance(email_info, GetEmailForUserIdOkResult): @@ -104,7 +104,7 @@ async def revoke_email_verification_tokens( if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: email_info = await recipe.get_email_for_user_id(user_id, user_context) if isinstance(email_info, GetEmailForUserIdOkResult): @@ -114,7 +114,7 @@ async def revoke_email_verification_tokens( else: raise Exception("Unknown User ID provided without email") - return await EmailVerificationRecipe.get_instance().recipe_implementation.revoke_email_verification_tokens( + return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.revoke_email_verification_tokens( user_id, email, tenant_id, user_context ) @@ -127,7 +127,7 @@ async def unverify_email( if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: email_info = await recipe.get_email_for_user_id(user_id, user_context) if isinstance(email_info, GetEmailForUserIdOkResult): @@ -139,7 +139,7 @@ async def unverify_email( else: raise Exception("Unknown User ID provided without email") - return await EmailVerificationRecipe.get_instance().recipe_implementation.unverify_email( + return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.unverify_email( user_id, email, user_context ) @@ -150,7 +150,7 @@ async def send_email( ): if user_context is None: user_context = {} - return await EmailVerificationRecipe.get_instance().email_delivery.ingredient_interface_impl.send_email( + return await EmailVerificationRecipe.get_instance_or_throw().email_delivery.ingredient_interface_impl.send_email( input_, user_context ) @@ -167,7 +167,7 @@ async def create_email_verification_link( if user_context is None: user_context = {} - recipe_instance = EmailVerificationRecipe.get_instance() + recipe_instance = EmailVerificationRecipe.get_instance_or_throw() app_info = recipe_instance.get_app_info() email_verification_token = await create_email_verification_token( @@ -203,7 +203,7 @@ async def send_email_verification_email( user_context = {} if email is None: - recipe_instance = EmailVerificationRecipe.get_instance() + recipe_instance = EmailVerificationRecipe.get_instance_or_throw() email_info = await recipe_instance.get_email_for_user_id(user_id, user_context) if isinstance(email_info, GetEmailForUserIdOkResult): diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index e02dd06ef..6ef7d3b9f 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -241,7 +241,7 @@ def callback(): return func @staticmethod - def get_instance() -> EmailVerificationRecipe: + def get_instance_or_throw() -> EmailVerificationRecipe: if EmailVerificationRecipe.__instance is not None: return EmailVerificationRecipe.__instance raise_general_exception( @@ -305,7 +305,7 @@ def __init__(self): async def fetch_value( user_id: str, _tenant_id: str, user_context: Dict[str, Any] ) -> bool: - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() email_info = await recipe.get_email_for_user_id(user_id, user_context) if isinstance(email_info, GetEmailForUserIdOkResult): @@ -395,8 +395,10 @@ async def generate_email_verify_token_post( GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError, ]: user_id = session.get_user_id(user_context) - email_info = await EmailVerificationRecipe.get_instance().get_email_for_user_id( - user_id, user_context + email_info = ( + await EmailVerificationRecipe.get_instance_or_throw().get_email_for_user_id( + user_id, user_context + ) ) tenant_id = session.get_tenant_id() diff --git a/supertokens_python/types.py b/supertokens_python/types.py index e54c2b589..6cecfc606 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -13,9 +13,10 @@ # under the License. from abc import ABC, abstractmethod -from typing import Any, Awaitable, Dict, List, TypeVar, Union +from typing import Any, Awaitable, Dict, List, TypeVar, Union, Optional from phonenumbers import format_number, parse # type: ignore import phonenumbers # type: ignore +from typing_extensions import Literal _T = TypeVar("_T") @@ -27,6 +28,11 @@ def __init__(self, recipe_user_id: str): def get_as_string(self) -> str: return self.recipe_user_id + def __eq__(self, other: Any) -> bool: + if isinstance(other, RecipeUserId): + return self.recipe_user_id == other.recipe_user_id + return False + class ThirdPartyInfo: def __init__(self, third_party_id: str, third_party_user_id: str): @@ -34,10 +40,22 @@ def __init__(self, third_party_id: str, third_party_user_id: str): self.user_id = third_party_user_id -class LoginMethod: +class AccountInfo: def __init__( self, - recipe_id: str, + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + self.email = email + self.phone_number = phone_number + self.third_party = third_party + + +class LoginMethod(AccountInfo): + def __init__( + self, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], recipe_user_id: str, tenant_ids: List[str], email: Union[str, None], @@ -46,12 +64,12 @@ def __init__( time_joined: int, verified: bool, ): - self.recipe_id = recipe_id + super().__init__(email, phone_number, third_party) + self.recipe_id: Literal[ + "emailpassword", "thirdparty", "passwordless" + ] = recipe_id self.recipe_user_id = RecipeUserId(recipe_user_id) - self.tenant_ids = tenant_ids - self.email = email - self.phone_number = phone_number - self.third_party = third_party + self.tenant_ids: List[str] = tenant_ids self.time_joined = time_joined self.verified = verified From 9b1ee92135c5a972ca569721a215ea4c1aeeda90 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 7 Aug 2024 18:50:35 +0530 Subject: [PATCH 010/126] completes recipe.py for accountlinking --- .pylintrc | 3 +- .../recipe/accountlinking/recipe.py | 306 ++++++++++++++++++ 2 files changed, 308 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index bd9c207ce..91bce55e8 100644 --- a/.pylintrc +++ b/.pylintrc @@ -119,7 +119,8 @@ disable=raw-checker-failed, too-many-return-statements, logging-not-lazy, logging-fstring-interpolation, - consider-using-f-string + consider-using-f-string, + consider-using-in # Enable the message, report, category or checker with the given id(s). You can diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index b6ddc021a..197ef203d 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -63,6 +63,14 @@ def __init__( self.reason = reason +class TryLinkingByAccountInfoOrCreatePrimaryUserResult: + def __init__( + self, status: Literal["OK", "NO_LINK"], user: Optional[AccountLinkingUser] + ): + self.status = status + self.user = user + + class AccountLinkingRecipe(RecipeModule): recipe_id = "accountlinking" __instance = None @@ -726,3 +734,301 @@ async def verify_email_for_recipe_user_if_linked_accounts_are_verified( # attempt_account_linking=False, user_context=user_context, ) + + async def should_become_primary_user( + self, + user: AccountLinkingUser, + tenant_id: str, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> bool: + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + user.login_methods[0] + ), + None, + session, + tenant_id, + user_context, + ) + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + "should_become_primary_user returning false because shouldAutomaticallyLink is false" + ) + return False + + if ( + should_do_account_linking.should_require_verification + and not user.login_methods[0].verified + ): + log_debug_message( + "should_become_primary_user returning false because shouldRequireVerification is true but the login method is not verified" + ) + return False + + log_debug_message("should_become_primary_user returning true") + return True + + async def try_linking_by_account_info_or_create_primary_user( + self, + input_user: AccountLinkingUser, + session: Optional[SessionContainer], + tenant_id: str, + user_context: Dict[str, Any], + ) -> TryLinkingByAccountInfoOrCreatePrimaryUserResult: + tries = 0 + while tries < 100: + tries += 1 + primary_user_that_can_be_linked_to_the_input_user = ( + await self.get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id=tenant_id, + user=input_user, + user_context=user_context, + ) + ) + if primary_user_that_can_be_linked_to_the_input_user is not None: + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: got primary user we can try linking" + ) + # we check if the input_user and primary_user_that_can_be_linked_to_the_input_user are linked based on recipeIds because the input_user obj could be outdated + if not any( + lm.recipe_user_id.get_as_string() + == input_user.login_methods[0].recipe_user_id.get_as_string() + for lm in primary_user_that_can_be_linked_to_the_input_user.login_methods + ): + should_do_account_linking = await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + input_user.login_methods[0] + ), + primary_user_that_can_be_linked_to_the_input_user, + session, + tenant_id, + user_context, + ) + + if isinstance( + should_do_account_linking, ShouldNotAutomaticallyLink + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldAutomaticallyLink is false" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + account_info_verified_in_prim_user = any( + ( + input_user.login_methods[0].email is not None + and lm.has_same_email_as(input_user.login_methods[0].email) + ) + or ( + input_user.login_methods[0].phone_number is not None + and lm.has_same_phone_number_as( + input_user.login_methods[0].phone_number + ) + and lm.verified + ) + for lm in primary_user_that_can_be_linked_to_the_input_user.login_methods + ) + if should_do_account_linking.should_require_verification and ( + not input_user.login_methods[0].verified + or not account_info_verified_in_prim_user + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldRequireVerification is true but the login method is not verified in the new or the primary user" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking" + ) + link_accounts_result = await self.recipe_implementation.link_accounts( + recipe_user_id=input_user.login_methods[0].recipe_user_id, + primary_user_id=primary_user_that_can_be_linked_to_the_input_user.id, + user_context=user_context, + ) + + # pylint:disable=no-else-return + if link_accounts_result.status == "OK": + log_debug_message( + "try_linking_by_account_info_or_create_primary_user successfully linked" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user already linked to another user" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + else: + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=input_user + ) + + oldest_user_that_can_be_linked_to_the_input_user = ( + await self.get_oldest_user_that_can_be_linked_to_recipe_user( + tenant_id=tenant_id, + user=input_user, + user_context=user_context, + ) + ) + if ( + oldest_user_that_can_be_linked_to_the_input_user is not None + and oldest_user_that_can_be_linked_to_the_input_user.id != input_user.id + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: got an older user we can try linking" + ) + should_make_older_user_primary = await self.should_become_primary_user( + oldest_user_that_can_be_linked_to_the_input_user, + tenant_id, + session, + user_context, + ) + if should_make_older_user_primary: + create_primary_user_result = await self.recipe_implementation.create_primary_user( + recipe_user_id=oldest_user_that_can_be_linked_to_the_input_user.login_methods[ + 0 + ].recipe_user_id, + user_context=user_context, + ) + if ( + create_primary_user_result.status + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + or create_primary_user_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + f"try_linking_by_account_info_or_create_primary_user: retrying because createPrimaryUser returned {create_primary_user_result.status}" + ) + continue + should_do_account_linking = await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + input_user.login_methods[0] + ), + create_primary_user_result.user, + session, + tenant_id, + user_context, + ) + + if isinstance( + should_do_account_linking, ShouldNotAutomaticallyLink + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldAutomaticallyLink is false" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + if ( + should_do_account_linking.should_require_verification + and not input_user.login_methods[0].verified + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldRequireVerification is true but the login method is not verified" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking" + ) + link_accounts_result = ( + await self.recipe_implementation.link_accounts( + recipe_user_id=input_user.login_methods[0].recipe_user_id, + primary_user_id=create_primary_user_result.user.id, + user_context=user_context, + ) + ) + + # pylint:disable=no-else-return + if link_accounts_result.status == "OK": + log_debug_message( + "try_linking_by_account_info_or_create_primary_user successfully linked" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user already linked to another user" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + else: + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: trying to make the current user primary" + ) + # pylint:disable=no-else-return + if await self.should_become_primary_user( + input_user, tenant_id, session, user_context + ): + create_primary_user_result = ( + await self.recipe_implementation.create_primary_user( + recipe_user_id=input_user.login_methods[0].recipe_user_id, + user_context=user_context, + ) + ) + + if ( + create_primary_user_result.status + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + or create_primary_user_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + continue + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", + user=create_primary_user_result.user, + ) + else: + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=input_user + ) + + raise Exception( + "This should never happen: ran out of retries for try_linking_by_account_info_or_create_primary_user" + ) From fbd3d889832fb588e450a7a0e346b5f02d1e8c46 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 8 Aug 2024 11:14:32 +0530 Subject: [PATCH 011/126] exposes account linking functions --- supertokens_python/asyncio/__init__.py | 11 + .../recipe/accountlinking/asyncio/__init__.py | 191 ++++++++++++++++++ .../recipe/accountlinking/recipe.py | 6 +- .../recipe/accountlinking/syncio/__init__.py | 141 +++++++++++++ supertokens_python/syncio/__init__.py | 9 + 5 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 supertokens_python/recipe/accountlinking/asyncio/__init__.py create mode 100644 supertokens_python/recipe/accountlinking/syncio/__init__.py diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index a0251620b..e0284ab54 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -26,6 +26,7 @@ ) from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult +from supertokens_python.types import AccountLinkingUser async def get_users_oldest_first( @@ -94,6 +95,16 @@ async def delete_user( ) +async def get_user( + user_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Optional[AccountLinkingUser]: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + user_id=user_id, user_context=user_context + ) + + async def create_user_id_mapping( supertokens_user_id: str, external_user_id: str, diff --git a/supertokens_python/recipe/accountlinking/asyncio/__init__.py b/supertokens_python/recipe/accountlinking/asyncio/__init__.py new file mode 100644 index 000000000..b104f37d5 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/asyncio/__init__.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional + +from ..types import AccountInfoWithRecipeId, AccountLinkingUser, RecipeUserId +from ..recipe import AccountLinkingRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.asyncio import get_user + + +async def create_primary_user_id_or_link_accounts( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> AccountLinkingUser: + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Unknown recipeUserId") + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + input_user=user, + tenant_id=tenant_id, + session=session, + user_context=user_context, + ) + if link_res.status == "NO_LINK": + return user + assert link_res.user is not None + return link_res.user + + +async def get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id: str, + recipe_user_id: RecipeUserId, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[AccountLinkingUser]: + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Unknown recipeUserId") + return await AccountLinkingRecipe.get_instance().get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id=tenant_id, + user=user, + user_context=user_context, + ) + + +async def can_create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.can_create_primary_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + +async def create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.create_primary_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + +async def can_link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.can_link_accounts( + recipe_user_id=recipe_user_id, + primary_user_id=primary_user_id, + user_context=user_context, + ) + + +async def link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return ( + await AccountLinkingRecipe.get_instance().recipe_implementation.link_accounts( + recipe_user_id=recipe_user_id, + primary_user_id=primary_user_id, + user_context=user_context, + ) + ) + + +async def unlink_account( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + if user_context is None: + user_context = {} + return ( + await AccountLinkingRecipe.get_instance().recipe_implementation.unlink_account( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + ) + + +async def is_sign_up_allowed( + tenant_id: str, + new_user: AccountInfoWithRecipeId, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().is_sign_up_allowed( + new_user=new_user, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + user_context=user_context, + ) + + +async def is_sign_in_allowed( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Unknown recipeUserId") + + return await AccountLinkingRecipe.get_instance().is_sign_in_allowed( + user=user, + account_info=next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == recipe_user_id.get_as_string() + ), + session=session, + tenant_id=tenant_id, + sign_in_verifies_login_method=False, + user_context=user_context, + ) + + +async def is_email_change_allowed( + recipe_user_id: RecipeUserId, + new_email: str, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Passed in recipe user id does not exist") + + res = await AccountLinkingRecipe.get_instance().is_email_change_allowed( + user=user, + new_email=new_email, + is_verified=is_verified, + session=session, + user_context=user_context, + ) + return res.allowed diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 197ef203d..5ae58eaf0 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -60,14 +60,16 @@ def __init__( reason: Literal["OK", "PRIMARY_USER_CONFLICT", "ACCOUNT_TAKEOVER_RISK"], ): self.allowed = allowed - self.reason = reason + self.reason: Literal[ + "OK", "PRIMARY_USER_CONFLICT", "ACCOUNT_TAKEOVER_RISK" + ] = reason class TryLinkingByAccountInfoOrCreatePrimaryUserResult: def __init__( self, status: Literal["OK", "NO_LINK"], user: Optional[AccountLinkingUser] ): - self.status = status + self.status: Literal["OK", "NO_LINK"] = status self.user = user diff --git a/supertokens_python/recipe/accountlinking/syncio/__init__.py b/supertokens_python/recipe/accountlinking/syncio/__init__.py new file mode 100644 index 000000000..832b5dfad --- /dev/null +++ b/supertokens_python/recipe/accountlinking/syncio/__init__.py @@ -0,0 +1,141 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional + +from supertokens_python.async_to_sync_wrapper import sync + +from ..types import AccountInfoWithRecipeId, AccountLinkingUser, RecipeUserId +from supertokens_python.recipe.session import SessionContainer + + +def create_primary_user_id_or_link_accounts( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> AccountLinkingUser: + from ..asyncio import ( + create_primary_user_id_or_link_accounts as async_create_primary_user_id_or_link_accounts, + ) + + return sync( + async_create_primary_user_id_or_link_accounts( + tenant_id, recipe_user_id, session, user_context + ) + ) + + +def get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id: str, + recipe_user_id: RecipeUserId, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[AccountLinkingUser]: + from ..asyncio import ( + get_primary_user_that_can_be_linked_to_recipe_user_id as async_get_primary_user_that_can_be_linked_to_recipe_user_id, + ) + + return sync( + async_get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id, recipe_user_id, user_context + ) + ) + + +def can_create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + from ..asyncio import can_create_primary_user as async_can_create_primary_user + + return sync(async_can_create_primary_user(recipe_user_id, user_context)) + + +def create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + from ..asyncio import create_primary_user as async_create_primary_user + + return sync(async_create_primary_user(recipe_user_id, user_context)) + + +def can_link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import can_link_accounts as async_can_link_accounts + + return sync(async_can_link_accounts(recipe_user_id, primary_user_id, user_context)) + + +def link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import link_accounts as async_link_accounts + + return sync(async_link_accounts(recipe_user_id, primary_user_id, user_context)) + + +def unlink_account( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + from ..asyncio import unlink_account as async_unlink_account + + return sync(async_unlink_account(recipe_user_id, user_context)) + + +def is_sign_up_allowed( + tenant_id: str, + new_user: AccountInfoWithRecipeId, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import is_sign_up_allowed as async_is_sign_up_allowed + + return sync( + async_is_sign_up_allowed( + tenant_id, new_user, is_verified, session, user_context + ) + ) + + +def is_sign_in_allowed( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import is_sign_in_allowed as async_is_sign_in_allowed + + return sync( + async_is_sign_in_allowed(tenant_id, recipe_user_id, session, user_context) + ) + + +def is_email_change_allowed( + recipe_user_id: RecipeUserId, + new_email: str, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import is_email_change_allowed as async_is_email_change_allowed + + return sync( + async_is_email_change_allowed( + recipe_user_id, new_email, is_verified, session, user_context + ) + ) diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index 415539771..85f9d6c89 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -25,6 +25,7 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) +from supertokens_python.types import AccountLinkingUser def get_users_oldest_first( @@ -93,6 +94,14 @@ def delete_user( return sync(delete_user(user_id, remove_all_linked_accounts, user_context)) +def get_user( + user_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Optional[AccountLinkingUser]: + from supertokens_python.asyncio import get_user as async_get_user + + return sync(async_get_user(user_id, user_context)) + + def create_user_id_mapping( supertokens_user_id: str, external_user_id: str, From d091f6dd20d02bb365088152d85d6d13aaaad9fa Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 8 Aug 2024 11:50:09 +0530 Subject: [PATCH 012/126] adds recipe impl for account linking --- .pylintrc | 3 +- .../recipe/accountlinking/recipe.py | 8 +- .../accountlinking/recipe_implementation.py | 254 ++++++++++++++++-- .../recipe/accountlinking/types.py | 13 + supertokens_python/types.py | 30 +++ 5 files changed, 283 insertions(+), 25 deletions(-) diff --git a/.pylintrc b/.pylintrc index 91bce55e8..bcab15182 100644 --- a/.pylintrc +++ b/.pylintrc @@ -120,7 +120,8 @@ disable=raw-checker-failed, logging-not-lazy, logging-fstring-interpolation, consider-using-f-string, - consider-using-in + consider-using-in, + no-else-return # Enable the message, report, category or checker with the given id(s). You can diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 5ae58eaf0..b37885be3 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -105,7 +105,7 @@ def __init__( app_info, on_account_linked, should_do_automatic_account_linking, override ) recipe_implementation: RecipeInterface = RecipeImplementation( - Querier.get_instance(recipe_id) + Querier.get_instance(recipe_id), self, self.config ) self.recipe_implementation = ( @@ -376,7 +376,6 @@ async def is_sign_in_up_allowed_helper( primary_users = [u for u in users if u.is_primary_user] - # pylint:disable=no-else-return if not primary_users: log_debug_message("isSignInUpAllowedHelper no primary user exists") should_do_account_linking = ( @@ -568,7 +567,6 @@ async def is_email_change_allowed( "You found a bug. Please report it on github.com/supertokens/supertokens-core" ) - # pylint:disable=no-else-return if user.is_primary_user: if other_primary_user_for_new_email: log_debug_message( @@ -855,7 +853,6 @@ async def try_linking_by_account_info_or_create_primary_user( user_context=user_context, ) - # pylint:disable=no-else-return if link_accounts_result.status == "OK": log_debug_message( "try_linking_by_account_info_or_create_primary_user successfully linked" @@ -969,7 +966,6 @@ async def try_linking_by_account_info_or_create_primary_user( ) ) - # pylint:disable=no-else-return if link_accounts_result.status == "OK": log_debug_message( "try_linking_by_account_info_or_create_primary_user successfully linked" @@ -1004,7 +1000,7 @@ async def try_linking_by_account_info_or_create_primary_user( log_debug_message( "try_linking_by_account_info_or_create_primary_user: trying to make the current user primary" ) - # pylint:disable=no-else-return + if await self.should_become_primary_user( input_user, tenant_id, session, user_context ): diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index e01c1d7ef..8db051480 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -39,18 +39,25 @@ RecipeUserId, AccountInfo, ) +from supertokens_python.normalised_url_path import NormalisedURLPath +from .types import AccountLinkingConfig, RecipeLevelUser if TYPE_CHECKING: from supertokens_python.querier import Querier + from .recipe import AccountLinkingRecipe class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, + recipe_instance: AccountLinkingRecipe, + config: AccountLinkingConfig, ): super().__init__() self.querier = querier + self.recipe_instance = recipe_instance + self.config = config async def get_users( self, @@ -62,8 +69,27 @@ async def get_users( query: Optional[Dict[str, str]], user_context: Dict[str, Any], ) -> GetUsersResult: - # Implementation for get_users - raise NotImplementedError("get_users") + include_recipe_ids_str = None + if include_recipe_ids is not None: + include_recipe_ids_str = ",".join(include_recipe_ids) + + params = { + "includeRecipeIds": include_recipe_ids_str, + "timeJoinedOrder": time_joined_order, + "limit": limit, + "paginationToken": pagination_token, + } + if query: + params.update(query) + + response = await self.querier.send_get_request( + NormalisedURLPath(f"/{tenant_id or 'public'}/users"), params, user_context + ) + + return GetUsersResult( + users=[AccountLinkingUser.from_json(u) for u in response["users"]], + next_pagination_token=response.get("nextPaginationToken"), + ) async def can_create_primary_user( self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] @@ -72,8 +98,32 @@ async def can_create_primary_user( CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, ]: - # Implementation for can_create_primary_user - raise NotImplementedError("can_create_primary_user") + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/accountlinking/user/primary/check"), + { + "recipeUserId": recipe_user_id.get_as_string(), + }, + user_context, + ) + + if response["status"] == "OK": + return CanCreatePrimaryUserOkResult(response["wasAlreadyAPrimaryUser"]) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError( + response["primaryUserId"], response["description"] + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CanCreatePrimaryUserAccountInfoAlreadyAssociatedError( + response["primaryUserId"], response["description"] + ) + else: + raise Exception(f"Unknown response status: {response['status']}") async def create_primary_user( self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] @@ -82,8 +132,35 @@ async def create_primary_user( CreatePrimaryUserRecipeUserIdAlreadyLinkedError, CreatePrimaryUserAccountInfoAlreadyAssociatedError, ]: - # Implementation for create_primary_user - raise NotImplementedError("create_primary_user") + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/accountlinking/user/primary"), + { + "recipeUserId": recipe_user_id.get_as_string(), + }, + user_context, + ) + + if response["status"] == "OK": + return CreatePrimaryUserOkResult( + AccountLinkingUser.from_json(response["user"]), + response["wasAlreadyAPrimaryUser"], + ) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CreatePrimaryUserRecipeUserIdAlreadyLinkedError( + response["primaryUserId"], response["description"] + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CreatePrimaryUserAccountInfoAlreadyAssociatedError( + response["primaryUserId"], response["description"] + ) + else: + raise Exception(f"Unknown response status: {response['status']}") async def can_link_accounts( self, @@ -96,8 +173,35 @@ async def can_link_accounts( CanLinkAccountsAccountInfoAlreadyAssociatedError, CanLinkAccountsInputUserNotPrimaryError, ]: - # Implementation for can_link_accounts - raise NotImplementedError("can_link_accounts") + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/accountlinking/user/link/check"), + { + "recipeUserId": recipe_user_id.get_as_string(), + "primaryUserId": primary_user_id, + }, + user_context, + ) + + if response["status"] == "OK": + return CanLinkAccountsOkResult(response["accountsAlreadyLinked"]) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CanLinkAccountsRecipeUserIdAlreadyLinkedError( + response["primaryUserId"], response["description"] + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CanLinkAccountsAccountInfoAlreadyAssociatedError( + response["primaryUserId"], response["description"] + ) + elif response["status"] == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + return CanLinkAccountsInputUserNotPrimaryError(response["description"]) + else: + raise Exception(f"Unknown response status: {response['status']}") async def link_accounts( self, @@ -110,20 +214,113 @@ async def link_accounts( LinkAccountsAccountInfoAlreadyAssociatedError, LinkAccountsInputUserNotPrimaryError, ]: - # Implementation for link_accounts - raise NotImplementedError("link_accounts") + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/accountlinking/user/link"), + { + "recipeUserId": recipe_user_id.get_as_string(), + "primaryUserId": primary_user_id, + }, + user_context, + ) + + if response["status"] in [ + "OK", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + ]: + response["user"] = AccountLinkingUser.from_json(response["user"]) + + if response["status"] == "OK": + user = response["user"] + if not response["accountsAlreadyLinked"]: + await self.recipe_instance.verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + updated_user = await self.get_user( + user_id=primary_user_id, + user_context=user_context, + ) + if updated_user is None: + raise Exception("This error should never be thrown") + user = updated_user + + login_method_info = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == recipe_user_id.get_as_string() + ), + None, + ) + if login_method_info is None: + raise Exception("This error should never be thrown") + + await self.config.on_account_linked( + user, + RecipeLevelUser.from_login_method(login_method_info), + user_context, + ) + + response["user"] = user + + if response["status"] == "OK": + return LinkAccountsOkResult( + user=response["user"], + accounts_already_linked=response["accountsAlreadyLinked"], + ) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return LinkAccountsRecipeUserIdAlreadyLinkedError( + primary_user_id=response["primaryUserId"], + description=response["description"], + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return LinkAccountsAccountInfoAlreadyAssociatedError( + primary_user_id=response["primaryUserId"], + description=response["description"], + ) + elif response["status"] == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + return LinkAccountsInputUserNotPrimaryError( + description=response["description"], + ) + else: + raise Exception(f"Unknown response status: {response['status']}") async def unlink_account( self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] ) -> UnlinkAccountOkResult: - # Implementation for unlink_account - raise NotImplementedError("unlink_account") + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/accountlinking/user/unlink"), + { + "recipeUserId": recipe_user_id.get_as_string(), + }, + user_context, + ) + return UnlinkAccountOkResult( + response["wasRecipeUserDeleted"], response["wasLinked"] + ) async def get_user( self, user_id: str, user_context: Dict[str, Any] ) -> Optional[AccountLinkingUser]: - # Implementation for get_user - raise NotImplementedError("get_user") + response = await self.querier.send_get_request( + NormalisedURLPath("/user/id"), + { + "userId": user_id, + }, + user_context, + ) + if response["status"] == "OK": + return AccountLinkingUser.from_json(response["user"]) + return None async def list_users_by_account_info( self, @@ -132,8 +329,23 @@ async def list_users_by_account_info( do_union_of_account_info: bool, user_context: Dict[str, Any], ) -> List[AccountLinkingUser]: - # Implementation for list_users_by_account_info - raise NotImplementedError("list_users_by_account_info") + params = { + "email": account_info.email, + "phoneNumber": account_info.phone_number, + "doUnionOfAccountInfo": do_union_of_account_info, + } + + if account_info.third_party: + params["thirdPartyId"] = account_info.third_party.id + params["thirdPartyUserId"] = account_info.third_party.user_id + + response = await self.querier.send_get_request( + NormalisedURLPath(f"/{tenant_id or 'public'}/users/by-accountinfo"), + params, + user_context, + ) + + return [AccountLinkingUser.from_json(u) for u in response["users"]] async def delete_user( self, @@ -141,5 +353,11 @@ async def delete_user( remove_all_linked_accounts: bool, user_context: Dict[str, Any], ) -> None: - # Implementation for delete_user - raise NotImplementedError("delete_user") + await self.querier.send_post_request( + NormalisedURLPath("/user/remove"), + { + "userId": user_id, + "removeAllLinkedAccounts": remove_all_linked_accounts, + }, + user_context, + ) diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index e8520a605..0b380bac7 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -61,6 +61,19 @@ def __init__( "emailpassword", "thirdparty", "passwordless" ] = recipe_id + @staticmethod + def from_login_method( + login_method: LoginMethod, + ) -> RecipeLevelUser: + return RecipeLevelUser( + tenant_ids=login_method.tenant_ids, + time_joined=login_method.time_joined, + recipe_id=login_method.recipe_id, + email=login_method.email, + phone_number=login_method.phone_number, + third_party=login_method.third_party, + ) + class AccountInfoWithRecipeIdAndUserId(AccountInfoWithRecipeId): def __init__( diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 6cecfc606..4a753e885 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -117,6 +117,23 @@ def to_json(self) -> Dict[str, Any]: "verified": self.verified, } + @staticmethod + def from_json(json: Dict[str, Any]) -> "LoginMethod": + return LoginMethod( + recipe_id=json["recipeId"], + recipe_user_id=json["recipeUserId"], + tenant_ids=json["tenantIds"], + email=json["email"], + phone_number=json["phoneNumber"], + third_party=( + ThirdPartyInfo(json["thirdParty"]["id"], json["thirdParty"]["userId"]) + if json["thirdParty"] + else None + ), + time_joined=json["timeJoined"], + verified=json["verified"], + ) + class AccountLinkingUser: def __init__( @@ -151,6 +168,19 @@ def to_json(self) -> Dict[str, Any]: "timeJoined": self.time_joined, } + @staticmethod + def from_json(json: Dict[str, Any]) -> "AccountLinkingUser": + return AccountLinkingUser( + user_id=json["id"], + is_primary_user=json["isPrimaryUser"], + tenant_ids=json["tenantIds"], + emails=json["emails"], + phone_numbers=json["phoneNumbers"], + third_party=json["thirdParty"], + login_methods=[LoginMethod.from_json(lm) for lm in json["loginMethods"]], + time_joined=json["timeJoined"], + ) + class User: def __init__( From 9789fd3ece0775faca84ed418538596493ce8cb2 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 8 Aug 2024 15:05:32 +0530 Subject: [PATCH 013/126] modifies email verification recipe fully --- .pylintrc | 3 +- CHANGELOG.md | 5 + .../recipe/accountlinking/recipe.py | 28 +- .../api/userdetails/user_email_verify_get.py | 5 +- .../api/userdetails/user_email_verify_put.py | 9 +- .../user_email_verify_token_post.py | 8 +- .../recipe/emailpassword/recipe.py | 5 +- .../recipe/emailverification/__init__.py | 4 +- .../emailverification/asyncio/__init__.py | 54 ++- .../backward_compatibility/__init__.py | 8 +- .../emailverification/ev_claim_validators.py | 14 - .../recipe/emailverification/interfaces.py | 61 +-- .../recipe/emailverification/recipe.py | 359 ++++++++++++++---- .../recipe_implementation.py | 90 ++++- .../emailverification/syncio/__init__.py | 37 +- .../recipe/emailverification/types.py | 12 +- .../recipe/emailverification/utils.py | 8 +- .../recipe/passwordless/api/implementation.py | 6 +- .../recipe/passwordless/recipe.py | 5 +- .../recipe/session/asyncio/__init__.py | 6 +- .../claim_base_classes/boolean_claim.py | 4 +- .../claim_base_classes/primitive_claim.py | 4 +- .../recipe/session/interfaces.py | 7 +- .../recipe/session/recipe_implementation.py | 8 +- .../recipe/session/session_class.py | 11 +- .../session/session_request_functions.py | 6 +- .../recipe/thirdparty/api/implementation.py | 5 +- .../recipe/thirdparty/recipe.py | 5 - tests/auth-react/django3x/polls/views.py | 4 +- tests/auth-react/fastapi-server/app.py | 2 +- tests/auth-react/flask-server/app.py | 2 +- tests/emailpassword/test_emailverify.py | 7 +- .../input_validation/test_input_validation.py | 5 +- tests/passwordless/test_emaildelivery.py | 3 +- tests/sessions/claims/test_assert_claims.py | 2 +- .../claims/test_primitive_array_claim.py | 45 +-- tests/sessions/claims/test_primitive_claim.py | 35 +- tests/sessions/claims/utils.py | 5 +- 38 files changed, 601 insertions(+), 286 deletions(-) delete mode 100644 supertokens_python/recipe/emailverification/ev_claim_validators.py diff --git a/.pylintrc b/.pylintrc index bcab15182..814178ce3 100644 --- a/.pylintrc +++ b/.pylintrc @@ -121,7 +121,8 @@ disable=raw-checker-failed, logging-fstring-interpolation, consider-using-f-string, consider-using-in, - no-else-return + no-else-return, + no-self-use # Enable the message, report, category or checker with the given id(s). You can diff --git a/CHANGELOG.md b/CHANGELOG.md index 44d8edaf4..9d6605b1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +### Breaking changes +- `supertokens_python.recipe.emailverification.types.User` has been renamed to `supertokens_python.recipe.emailverification.types.EmailVerificationUser` +- The user object has been changed to be a global one, containing information about all emails, phone numbers, third party info and login methods associated with that user. +- Type of `get_email_for_user_id` in `emailverification.init` has changed + ## [0.24.0] - 2024-07-31 ### Changes diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index b37885be3..3f4a06292 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -28,7 +28,6 @@ ) from supertokens_python.process_state import PROCESS_STATE, ProcessState from typing_extensions import Literal -from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe from .types import ( RecipeLevelUser, @@ -42,15 +41,14 @@ from .interfaces import RecipeInterface -from supertokens_python.recipe.emailverification.interfaces import ( - CreateEmailVerificationTokenOkResult, -) - if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo from supertokens_python.types import AccountLinkingUser, LoginMethod, RecipeUserId from supertokens_python.recipe.session import SessionContainer from supertokens_python.framework import BaseRequest, BaseResponse + from supertokens_python.recipe.emailverification.recipe import ( + EmailVerificationRecipe, + ) class EmailChangeAllowedResult: @@ -114,6 +112,13 @@ def __init__( else self.config.override.functions(recipe_implementation) ) + self.email_verification_recipe: EmailVerificationRecipe | None = None + + def register_email_verification_recipe( + self, email_verification_recipe: EmailVerificationRecipe + ): + self.email_verification_recipe = email_verification_recipe + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return False @@ -693,10 +698,7 @@ async def verify_email_for_recipe_user_if_linked_accounts_are_verified( recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> None: - try: - EmailVerificationRecipe.get_instance_or_throw() - except Exception: - # if email verification recipe is not initialized, we do a no-op + if self.email_verification_recipe is None: return if user.is_primary_user: @@ -718,20 +720,20 @@ async def verify_email_for_recipe_user_if_linked_accounts_are_verified( break if should_verify_email: - ev_recipe = EmailVerificationRecipe.get_instance_or_throw() + ev_recipe = self.email_verification_recipe.get_instance_or_throw() resp = await ev_recipe.recipe_implementation.create_email_verification_token( tenant_id=user.tenant_ids[0], - user_id=recipe_user_id.get_as_string(), + recipe_user_id=recipe_user_id, email=recipe_user_email, user_context=user_context, ) - if isinstance(resp, CreateEmailVerificationTokenOkResult): + if resp.status == "OK": # we purposely pass in false below cause we don't want account # linking to happen await ev_recipe.recipe_implementation.verify_email_using_token( tenant_id=user.tenant_ids[0], token=resp.token, - # attempt_account_linking=False, + attempt_account_linking=False, user_context=user_context, ) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py index f0652bd94..1b8ff544f 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py @@ -7,6 +7,7 @@ UserEmailVerifyGetAPIResponse, FeatureNotEnabledError, ) +from supertokens_python.types import RecipeUserId from typing import Union, Dict, Any @@ -28,5 +29,7 @@ async def handle_user_email_verify_get( except Exception: return FeatureNotEnabledError() - is_verified = await is_email_verified(user_id, user_context=user_context) + is_verified = await is_email_verified( + RecipeUserId(user_id), user_context=user_context + ) return UserEmailVerifyGetAPIResponse(is_verified) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py index fd4443ae1..e7a1e4700 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py @@ -11,6 +11,8 @@ VerifyEmailUsingTokenInvalidTokenError, ) +from supertokens_python.types import RecipeUserId + from ...interfaces import ( APIInterface, APIOptions, @@ -40,7 +42,10 @@ async def handle_user_email_verify_put( if verified: token_response = await create_email_verification_token( - tenant_id=tenant_id, user_id=user_id, email=None, user_context=user_context + tenant_id=tenant_id, + recipe_user_id=RecipeUserId(user_id), + email=None, + user_context=user_context, ) if isinstance( @@ -57,6 +62,6 @@ async def handle_user_email_verify_put( raise Exception("Should not come here") else: - await unverify_email(user_id, user_context=user_context) + await unverify_email(RecipeUserId(user_id), user_context=user_context) return UserEmailVerifyPutAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py index b00b0a84e..1392915c1 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py @@ -13,6 +13,8 @@ UserEmailVerifyTokenPostAPIEmailAlreadyVerifiedErrorResponse, ) +from supertokens_python.types import RecipeUserId + async def handle_email_verify_token_post( _api_interface: APIInterface, @@ -32,7 +34,11 @@ async def handle_email_verify_token_post( ) res = await send_email_verification_email( - tenant_id=tenant_id, user_id=user_id, email=None, user_context=user_context + tenant_id=tenant_id, + user_id=user_id, + recipe_user_id=RecipeUserId(user_id), + email=None, + user_context=user_context, ) if isinstance(res, SendEmailVerificationEmailAlreadyVerifiedError): diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index f4b681eac..abe38a7c6 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -43,7 +43,6 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.querier import Querier -from supertokens_python.recipe.emailverification import EmailVerificationRecipe from .api import ( handle_email_exists_api, @@ -118,9 +117,7 @@ def get_emailpassword_config() -> EmailPasswordConfig: ) def callback(): - ev_recipe = EmailVerificationRecipe.get_instance_optional() - if ev_recipe: - ev_recipe.add_get_email_for_user_id_func(self.get_email_for_user_id) + pass PostSTInitCallbacks.add_post_init_callback(callback) diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index 78d2134cc..ba1f67fc3 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -43,12 +43,12 @@ def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ) -> Callable[[AppInfo], RecipeModule]: return EmailVerificationRecipe.init( mode, email_delivery, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) diff --git a/supertokens_python/recipe/emailverification/asyncio/__init__.py b/supertokens_python/recipe/emailverification/asyncio/__init__.py index b0ead0877..c2f199067 100644 --- a/supertokens_python/recipe/emailverification/asyncio/__init__.py +++ b/supertokens_python/recipe/emailverification/asyncio/__init__.py @@ -28,7 +28,7 @@ ) from supertokens_python.recipe.emailverification.types import EmailTemplateVars from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe - +from supertokens_python.types import RecipeUserId from supertokens_python.recipe.emailverification.utils import get_email_verify_link from supertokens_python.recipe.emailverification.types import ( VerificationEmailTemplateVars, @@ -38,7 +38,7 @@ async def create_email_verification_token( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ @@ -49,7 +49,9 @@ async def create_email_verification_token( user_context = {} recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -58,22 +60,25 @@ async def create_email_verification_token( raise Exception("Unknown User ID provided without email") return await recipe.recipe_implementation.create_email_verification_token( - user_id, email, tenant_id, user_context + recipe_user_id, email, tenant_id, user_context ) async def verify_email_using_token( - tenant_id: str, token: str, user_context: Union[None, Dict[str, Any]] = None + tenant_id: str, + token: str, + attempt_account_linking: bool = True, + user_context: Union[None, Dict[str, Any]] = None, ): if user_context is None: user_context = {} return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.verify_email_using_token( - token, tenant_id, user_context + token, tenant_id, attempt_account_linking, user_context ) async def is_email_verified( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): @@ -82,7 +87,9 @@ async def is_email_verified( recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -91,13 +98,13 @@ async def is_email_verified( raise Exception("Unknown User ID provided without email") return await recipe.recipe_implementation.is_email_verified( - user_id, email, user_context + recipe_user_id, email, user_context ) async def revoke_email_verification_tokens( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, ) -> RevokeEmailVerificationTokensOkResult: @@ -106,7 +113,9 @@ async def revoke_email_verification_tokens( recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -115,12 +124,12 @@ async def revoke_email_verification_tokens( raise Exception("Unknown User ID provided without email") return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.revoke_email_verification_tokens( - user_id, email, tenant_id, user_context + recipe_user_id, email, tenant_id, user_context ) async def unverify_email( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): @@ -129,7 +138,9 @@ async def unverify_email( recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -140,7 +151,7 @@ async def unverify_email( raise Exception("Unknown User ID provided without email") return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.unverify_email( - user_id, email, user_context + recipe_user_id, email, user_context ) @@ -157,7 +168,7 @@ async def send_email( async def create_email_verification_link( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ) -> Union[ @@ -171,7 +182,7 @@ async def create_email_verification_link( app_info = recipe_instance.get_app_info() email_verification_token = await create_email_verification_token( - tenant_id, user_id, email, user_context + tenant_id, recipe_user_id, email, user_context ) if isinstance( email_verification_token, CreateEmailVerificationTokenEmailAlreadyVerifiedError @@ -193,6 +204,7 @@ async def create_email_verification_link( async def send_email_verification_email( tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ) -> Union[ @@ -205,7 +217,9 @@ async def send_email_verification_email( if email is None: recipe_instance = EmailVerificationRecipe.get_instance_or_throw() - email_info = await recipe_instance.get_email_for_user_id(user_id, user_context) + email_info = await recipe_instance.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -214,7 +228,7 @@ async def send_email_verification_email( raise Exception("Unknown User ID provided without email") email_verification_link = await create_email_verification_link( - tenant_id, user_id, email, user_context + tenant_id, recipe_user_id, email, user_context ) if isinstance( @@ -224,7 +238,7 @@ async def send_email_verification_email( await send_email( VerificationEmailTemplateVars( - user=VerificationEmailTemplateVarsUser(user_id, email), + user=VerificationEmailTemplateVarsUser(user_id, recipe_user_id, email), email_verify_link=email_verification_link.link, tenant_id=tenant_id, ), diff --git a/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py b/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py index b78891e3f..1214a8010 100644 --- a/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py +++ b/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py @@ -19,7 +19,7 @@ from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryInterface from supertokens_python.logger import log_debug_message from supertokens_python.recipe.emailverification.types import ( - User, + EmailVerificationUser, VerificationEmailTemplateVars, ) from supertokens_python.supertokens import AppInfo @@ -27,7 +27,7 @@ async def create_and_send_email_using_supertokens_service( - app_info: AppInfo, user: User, email_verification_url: str + app_info: AppInfo, user: EmailVerificationUser, email_verification_url: str ) -> None: if ("SUPERTOKENS_ENV" in environ) and (environ["SUPERTOKENS_ENV"] == "testing"): return @@ -62,7 +62,9 @@ async def send_email( user_context: Dict[str, Any], ) -> None: try: - email_user = User(template_vars.user.id, template_vars.user.email) + email_user = EmailVerificationUser( + template_vars.user.recipe_user_id, template_vars.user.email + ) await create_and_send_email_using_supertokens_service( self.app_info, email_user, template_vars.email_verify_link ) diff --git a/supertokens_python/recipe/emailverification/ev_claim_validators.py b/supertokens_python/recipe/emailverification/ev_claim_validators.py deleted file mode 100644 index dd5f414fc..000000000 --- a/supertokens_python/recipe/emailverification/ev_claim_validators.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from __future__ import annotations diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index bfa0b3e09..f1501e357 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -17,52 +17,53 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient -from supertokens_python.types import APIResponse, GeneralErrorResponse +from supertokens_python.types import APIResponse, GeneralErrorResponse, RecipeUserId from ...supertokens import AppInfo from ..session.interfaces import SessionContainer +from typing_extensions import Literal if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse - from .types import User, VerificationEmailTemplateVars + from .types import EmailVerificationUser, VerificationEmailTemplateVars from .utils import EmailVerificationConfig class CreateEmailVerificationTokenOkResult: - status = "OK" + status: Literal["OK"] = "OK" def __init__(self, token: str): self.token = token class CreateEmailVerificationTokenEmailAlreadyVerifiedError: - status = "EMAIL_ALREADY_VERIFIED_ERROR" + status: Literal["EMAIL_ALREADY_VERIFIED_ERROR"] = "EMAIL_ALREADY_VERIFIED_ERROR" class CreateEmailVerificationLinkEmailAlreadyVerifiedError: - status = "EMAIL_ALREADY_VERIFIED_ERROR" + status: Literal["EMAIL_ALREADY_VERIFIED_ERROR"] = "EMAIL_ALREADY_VERIFIED_ERROR" class CreateEmailVerificationLinkOkResult: - status = "OK" + status: Literal["OK"] = "OK" def __init__(self, link: str): self.link = link class SendEmailVerificationEmailAlreadyVerifiedError: - status = "EMAIL_ALREADY_VERIFIED_ERROR" + status: Literal["EMAIL_ALREADY_VERIFIED_ERROR"] = "EMAIL_ALREADY_VERIFIED_ERROR" class SendEmailVerificationEmailOkResult: - status = "OK" + status: Literal["OK"] = "OK" class VerifyEmailUsingTokenOkResult: - status = "OK" + status: Literal["OK"] = "OK" - def __init__(self, user: User): + def __init__(self, user: EmailVerificationUser): self.user = user @@ -84,7 +85,11 @@ def __init__(self): @abstractmethod async def create_email_verification_token( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> Union[ CreateEmailVerificationTokenOkResult, CreateEmailVerificationTokenEmailAlreadyVerifiedError, @@ -93,25 +98,33 @@ async def create_email_verification_token( @abstractmethod async def verify_email_using_token( - self, token: str, tenant_id: str, user_context: Dict[str, Any] + self, + token: str, + tenant_id: str, + attempt_account_linking: bool, + user_context: Dict[str, Any], ) -> Union[VerifyEmailUsingTokenOkResult, VerifyEmailUsingTokenInvalidTokenError]: pass @abstractmethod async def is_email_verified( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> bool: pass @abstractmethod async def revoke_email_verification_tokens( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> RevokeEmailVerificationTokensOkResult: pass @abstractmethod async def unverify_email( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> UnverifyEmailOkResult: pass @@ -137,15 +150,15 @@ def __init__( class EmailVerifyPostOkResult(APIResponse): - def __init__(self, user: User): + def __init__( + self, user: EmailVerificationUser, new_session: Optional[SessionContainer] + ): self.user = user + self.new_session = new_session self.status = "OK" def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "user": {"id": self.user.user_id, "email": self.user.email}, - } + return {"status": self.status} class EmailVerifyPostInvalidTokenError(APIResponse): @@ -157,9 +170,10 @@ def to_json(self) -> Dict[str, Any]: class IsEmailVerifiedGetOkResult(APIResponse): - def __init__(self, is_verified: bool): + def __init__(self, is_verified: bool, new_session: Optional[SessionContainer]): self.status = "OK" self.is_verified = is_verified + self.new_session = new_session def to_json(self) -> Dict[str, Any]: return {"status": self.status, "isVerified": self.is_verified} @@ -174,8 +188,9 @@ def to_json(self) -> Dict[str, Any]: class GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError(APIResponse): - def __init__(self): + def __init__(self, new_session: Optional[SessionContainer]): self.status = "EMAIL_ALREADY_VERIFIED_ERROR" + self.new_session = new_session def to_json(self) -> Dict[str, Any]: return {"status": self.status} @@ -237,7 +252,7 @@ class UnknownUserIdError(Exception): TypeGetEmailForUserIdFunction = Callable[ - [str, Dict[str, Any]], + [RecipeUserId, Dict[str, Any]], Awaitable[ Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError] ], diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 6ef7d3b9f..670adb84f 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -32,9 +32,9 @@ from ...ingredients.emaildelivery.types import EmailDeliveryConfig from ...logger import log_debug_message from ...post_init_callbacks import PostSTInitCallbacks -from ...types import MaybeAwaitable from ...utils import get_timestamp_ms from ..session import SessionRecipe +from ..session.asyncio import revoke_all_sessions_for_user, create_new_session from ..session.claim_base_classes.boolean_claim import ( BooleanClaim, BooleanClaimValidators, @@ -67,6 +67,8 @@ from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.supertokens import AppInfo + from supertokens_python.types import RecipeUserId + from ...types import AccountLinkingUser, MaybeAwaitable from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier @@ -90,7 +92,7 @@ def __init__( ingredients: EmailVerificationIngredients, mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ) -> None: super().__init__(recipe_id, app_info) @@ -98,12 +100,13 @@ def __init__( app_info, mode, email_delivery, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) recipe_implementation = RecipeImplementation( - Querier.get_instance(recipe_id), self.config + Querier.get_instance(recipe_id), + self.get_email_for_recipe_user_id, ) self.recipe_implementation = ( recipe_implementation @@ -126,10 +129,6 @@ def __init__( else: self.email_delivery = email_delivery_ingredient - self.get_email_for_user_id_funcs_from_other_recipes: List[ - TypeGetEmailForUserIdFunction - ] = [] - def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and isinstance( err, SuperTokensEmailVerificationError @@ -206,7 +205,7 @@ def get_all_cors_headers(self) -> List[str]: def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ): def func(app_info: AppInfo) -> EmailVerificationRecipe: @@ -218,7 +217,7 @@ def func(app_info: AppInfo) -> EmailVerificationRecipe: ingredients, mode, email_delivery, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) @@ -231,6 +230,15 @@ def callback(): EmailVerificationClaim.validators.is_verified() ) + from supertokens_python.recipe.accountlinking.recipe import ( + AccountLinkingRecipe, + ) + + assert EmailVerificationRecipe.__instance is not None + AccountLinkingRecipe.get_instance().register_email_verification_recipe( + EmailVerificationRecipe.__instance + ) + PostSTInitCallbacks.add_post_init_callback(callback) return EmailVerificationRecipe.__instance @@ -260,23 +268,186 @@ def reset(): raise_general_exception("calling testing function in non testing env") EmailVerificationRecipe.__instance = None - async def get_email_for_user_id( - self, user_id: str, user_context: Dict[str, Any] + async def get_email_for_recipe_user_id( + self, + user: Optional[AccountLinkingUser], + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError]: - if self.config.get_email_for_user_id is not None: - res = await self.config.get_email_for_user_id(user_id, user_context) - if not isinstance(res, UnknownUserIdError): - return res + if self.config.get_email_for_recipe_user_id is not None: + user_res = await self.config.get_email_for_recipe_user_id( + recipe_user_id, user_context + ) + if not isinstance(user_res, UnknownUserIdError): + return user_res + + if user is None: + from supertokens_python.recipe.accountlinking.recipe import ( + AccountLinkingRecipe, + ) + + user = await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + recipe_user_id.get_as_string(), user_context + ) - for f in self.get_email_for_user_id_funcs_from_other_recipes: - res = await f(user_id, user_context) - if not isinstance(res, UnknownUserIdError): - return res + if user is None: + return UnknownUserIdError() + + for login_method in user.login_methods: + if ( + login_method.recipe_user_id.get_as_string() + == recipe_user_id.get_as_string() + ): + if login_method.email is not None: + return GetEmailForUserIdOkResult(email=login_method.email) + else: + return EmailDoesNotExistError() return UnknownUserIdError() - def add_get_email_for_user_id_func(self, f: TypeGetEmailForUserIdFunction): - self.get_email_for_user_id_funcs_from_other_recipes.append(f) + async def get_primary_user_id_for_recipe_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> str: + # We extract this into its own function like this cause we want to make sure that + # this recipe does not get the email of the user ID from the getUser function. + # In fact, there is a test "email verification recipe uses getUser function only in getEmailForRecipeUserId" + # which makes sure that this function is only called in 3 places in this recipe: + # - this function + # - getEmailForRecipeUserId function (above) + # - after verification to get the updated user in verifyEmailUsingToken + # We want to isolate the result of calling this function as much as possible + # so that the consumer of the getUser function does not read the email + # from the primaryUser. Hence, this function only returns the string ID + # and nothing else from the primaryUser. + from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe + + primary_user = ( + await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + recipe_user_id.get_as_string(), user_context + ) + ) + if primary_user is None: + # This can come here if the user is using session + email verification + # recipe with a user ID that is not known to supertokens. In this case, + # we do not allow linking for such users. + return recipe_user_id.get_as_string() + return primary_user.id + + async def update_session_if_required_post_email_verification( + self, + req: BaseRequest, + session: Optional[SessionContainer], + recipe_user_id_whose_email_got_verified: RecipeUserId, + user_context: Dict[str, Any], + ) -> Optional[SessionContainer]: + primary_user_id = await self.get_primary_user_id_for_recipe_user( + recipe_user_id_whose_email_got_verified, user_context + ) + + # if a session exists in the API, then we can update the session + # claim related to email verification + if session is not None: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification got session" + ) + # Due to linking, we will have to correct the current + # session's user ID. There are four cases here: + # --> (Case 1) User signed up and did email verification and the new account + # became a primary user (user ID no change) + # --> (Case 2) User signed up and did email verification and the new account got linked + # to another primary user (user ID change) + # --> (Case 3) This is post login account linking, in which the account that got verified + # got linked to the session's account (user ID of account has changed to the session's user ID) + # --> (Case 4) This is post login account linking, in which the account that got verified + # got linked to ANOTHER primary account (user ID of account has changed to a different user ID != session.getUserId, but + # we should ignore this since it will result in the user's session changing.) + + if ( + session.get_recipe_user_id(user_context).get_as_string() + == recipe_user_id_whose_email_got_verified.get_as_string() + ): + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the session belongs to the verified user" + ) + # this means that the session's login method's account is the + # one that just got verified and that we are NOT doing post login + # account linking. So this is only for (Case 1) and (Case 2) + + if session.get_user_id() == primary_user_id: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the session userId matches the primary user id, so we are only refreshing the claim" + ) + # if the session's primary user ID is equal to the + # primary user ID that the account was linked to, then + # this means that the new account became a primary user (Case 1) + # We also have the sub cases here that the account that just + # got verified was already linked to the session's primary user ID, + # but either way, we don't need to change any user ID. + + # In this case, all we do is to update the emailverification claim + try: + # EmailVerificationClaim will be based on the recipeUserId + # and not the primary user ID. + await session.fetch_and_set_claim( + EmailVerificationClaim, user_context + ) + except Exception as err: + # This should never happen, since we've just set the status above. + if str(err) == "UNKNOWN_USER_ID": + raise_unauthorised_exception("Unknown User ID provided") + raise err + + return None + else: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the session user id doesn't match the primary user id, so we are revoking all sessions and creating a new one" + ) + # if the session's primary user ID is NOT equal to the + # primary user ID that the account that it was linked to, then + # this means that the new account got linked to another primary user (Case 2) + + # In this case, we need to update the session's user ID by creating + # a new session + + # Revoke all session belonging to session.getRecipeUserId() + # We do not really need to do this, but we do it anyway.. no harm. + await revoke_all_sessions_for_user( + recipe_user_id_whose_email_got_verified.get_as_string(), + # False, + session.get_tenant_id(), + user_context, + ) + + # create a new session and return that.. + return await create_new_session( + req, + session.get_tenant_id(), + session.get_recipe_user_id(user_context).get_as_string(), + {}, + {}, + user_context, + ) + else: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the verified user doesn't match the session" + ) + # this means that the session's login method's account was NOT the + # one that just got verified and that we ARE doing post login + # account linking. So this is only for (Case 3) and (Case 4) + + # In both case 3 and case 4, we do not want to change anything in the + # current session in terms of user ID or email verification claim (since + # both of these refer to the current logged in user and not the newly + # linked user's account). + + return None + else: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification got no session" + ) + # the session is updated when the is email verification GET API is called + # so we don't do anything in this API. + return None class EmailVerificationClaimValidators(BooleanClaimValidators): @@ -303,14 +474,19 @@ def is_verified( class EmailVerificationClaimClass(BooleanClaim): def __init__(self): async def fetch_value( - user_id: str, _tenant_id: str, user_context: Dict[str, Any] + _: str, + recipe_user_id: RecipeUserId, + __: str, + user_context: Dict[str, Any], ) -> bool: recipe = EmailVerificationRecipe.get_instance_or_throw() - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): return await recipe.recipe_implementation.is_email_verified( - user_id, email_info.email, user_context + recipe_user_id, email_info.email, user_context ) if isinstance(email_info, EmailDoesNotExistError): # we consider people without email addresses as validated @@ -336,25 +512,15 @@ async def email_verify_post( ) -> Union[EmailVerifyPostOkResult, EmailVerifyPostInvalidTokenError]: response = await api_options.recipe_implementation.verify_email_using_token( - token, tenant_id, user_context + token, tenant_id, True, user_context ) if isinstance(response, VerifyEmailUsingTokenOkResult): - if session is not None: - try: - await session.fetch_and_set_claim( - EmailVerificationClaim, user_context - ) - except Exception as e: - # This should never happen since we have just set the status above - if str(e) == "UNKNOWN_USER_ID": - log_debug_message( - "verifyEmailPOST: Returning UNAUTHORISED because the user id provided is unknown" - ) - raise_unauthorised_exception("Unknown User ID provided") - else: - raise e + email_verification_recipe = EmailVerificationRecipe.get_instance_or_throw() + new_session = await email_verification_recipe.update_session_if_required_post_email_verification( + api_options.request, session, response.user.recipe_user_id, user_context + ) - return EmailVerifyPostOkResult(response.user) + return EmailVerifyPostOkResult(response.user, new_session) return EmailVerifyPostInvalidTokenError() async def is_email_verified_get( @@ -363,27 +529,40 @@ async def is_email_verified_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> IsEmailVerifiedGetOkResult: - try: - await session.fetch_and_set_claim(EmailVerificationClaim, user_context) - except Exception as e: - if str(e) == "UNKNOWN_USER_ID": - log_debug_message( - "isEmailVerifiedGET: Returning UNAUTHORISED because the user id provided is unknown" - ) - raise_unauthorised_exception("Unknown User ID provided") - else: - raise e - - is_verified = await session.get_claim_value( - EmailVerificationClaim, user_context + recipe = EmailVerificationRecipe.get_instance_or_throw() + email_info = await recipe.get_email_for_recipe_user_id( + None, session.get_recipe_user_id(user_context), user_context ) - if is_verified is None: - raise Exception( - "Should never come here: EmailVerificationClaim failed to set value" + if isinstance(email_info, GetEmailForUserIdOkResult): + is_verified = await api_options.recipe_implementation.is_email_verified( + session.get_recipe_user_id(user_context), email_info.email, user_context ) - return IsEmailVerifiedGetOkResult(is_verified) + if is_verified: + new_session = ( + await recipe.update_session_if_required_post_email_verification( + api_options.request, + session, + session.get_recipe_user_id(user_context), + user_context, + ) + ) + return IsEmailVerifiedGetOkResult(True, new_session) + else: + await session.set_claim_value( + EmailVerificationClaim, False, user_context + ) + return IsEmailVerifiedGetOkResult(False, None) + elif isinstance(email_info, EmailDoesNotExistError): + # We consider people without email addresses as validated + return IsEmailVerifiedGetOkResult(True, None) + else: + # This means that the user ID is not known to supertokens. This could + # happen if the current session's user ID is not an auth user, + # or if it belongs to a recipe user ID that got deleted. Either way, + # we logout the user. + raise_unauthorised_exception("Unknown User ID provided") async def generate_email_verify_token_post( self, @@ -394,24 +573,30 @@ async def generate_email_verify_token_post( GenerateEmailVerifyTokenPostOkResult, GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError, ]: - user_id = session.get_user_id(user_context) - email_info = ( - await EmailVerificationRecipe.get_instance_or_throw().get_email_for_user_id( - user_id, user_context - ) - ) tenant_id = session.get_tenant_id() + email_info = await EmailVerificationRecipe.get_instance_or_throw().get_email_for_recipe_user_id( + None, session.get_recipe_user_id(user_context), user_context + ) + if isinstance(email_info, EmailDoesNotExistError): log_debug_message( "Email verification email not sent to user %s because it doesn't have an email address.", - user_id, + session.get_recipe_user_id(user_context).get_as_string(), ) - return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError() - if isinstance(email_info, GetEmailForUserIdOkResult): + # This can happen if the user ID was found, but it has no email. In this + # case, we treat it as a success case. + new_session = await EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + api_options.request, + session, + session.get_recipe_user_id(user_context), + user_context, + ) + return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError(new_session) + elif isinstance(email_info, GetEmailForUserIdOkResult): response = ( await api_options.recipe_implementation.create_email_verification_token( - user_id, + session.get_recipe_user_id(user_context), email_info.email, tenant_id, user_context, @@ -421,21 +606,25 @@ async def generate_email_verify_token_post( if isinstance( response, CreateEmailVerificationTokenEmailAlreadyVerifiedError ): - if await session.get_claim_value(EmailVerificationClaim) is not True: - # this can happen if the email was "verified" in another browser - # and this session is still outdated - and the user has not - # called the get email verification API yet. - await session.fetch_and_set_claim( - EmailVerificationClaim, user_context - ) log_debug_message( - "Email verification email not sent to %s because it is already verified.", - email_info.email, + "Email verification email not sent to user %s because it is already verified.", + session.get_recipe_user_id(user_context).get_as_string(), + ) + new_session = await EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + api_options.request, + session, + session.get_recipe_user_id(user_context), + user_context, + ) + return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError( + new_session ) - return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError() - if await session.get_claim_value(EmailVerificationClaim) is not False: - # this can happen if the email was "unverified" in another browser + if ( + await session.get_claim_value(EmailVerificationClaim, user_context) + is not False + ): + # This can happen if the email was unverified in another browser # and this session is still outdated - and the user has not # called the get email verification API yet. await session.fetch_and_set_claim(EmailVerificationClaim, user_context) @@ -448,9 +637,15 @@ async def generate_email_verify_token_post( user_context, ) - log_debug_message("Sending email verification email to %s", email_info) + log_debug_message( + "Sending email verification email to %s", email_info.email + ) email_verification_email_delivery_input = VerificationEmailTemplateVars( - user=VerificationEmailTemplateVarsUser(user_id, email_info.email), + user=VerificationEmailTemplateVarsUser( + _id=session.get_user_id(user_context), + recipe_user_id=session.get_recipe_user_id(user_context), + email=email_info.email, + ), email_verify_link=email_verify_link, tenant_id=tenant_id, ) diff --git a/supertokens_python/recipe/emailverification/recipe_implementation.py b/supertokens_python/recipe/emailverification/recipe_implementation.py index c4b17f237..097ecbe54 100644 --- a/supertokens_python/recipe/emailverification/recipe_implementation.py +++ b/supertokens_python/recipe/emailverification/recipe_implementation.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Union +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Union, Optional, Callable from supertokens_python.normalised_url_path import NormalisedURLPath @@ -25,28 +25,48 @@ UnverifyEmailOkResult, VerifyEmailUsingTokenInvalidTokenError, VerifyEmailUsingTokenOkResult, + GetEmailForUserIdOkResult, + EmailDoesNotExistError, + UnknownUserIdError, ) -from .types import User +from .types import EmailVerificationUser +from supertokens_python.asyncio import get_user if TYPE_CHECKING: from supertokens_python.querier import Querier - - from .utils import EmailVerificationConfig + from supertokens_python.types import RecipeUserId, AccountLinkingUser class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, config: EmailVerificationConfig): + def __init__( + self, + querier: Querier, + get_email_for_recipe_user_id: Callable[ + [Optional[AccountLinkingUser], RecipeUserId, Dict[str, Any]], + Awaitable[ + Union[ + GetEmailForUserIdOkResult, + EmailDoesNotExistError, + UnknownUserIdError, + ] + ], + ], + ): super().__init__() self.querier = querier - self.config = config + self.get_email_for_recipe_user_id = get_email_for_recipe_user_id async def create_email_verification_token( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> Union[ CreateEmailVerificationTokenOkResult, CreateEmailVerificationTokenEmailAlreadyVerifiedError, ]: - data = {"userId": user_id, "email": email} + data = {"userId": recipe_user_id.get_as_string(), "email": email} response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/user/email/verify/token"), data, @@ -57,7 +77,11 @@ async def create_email_verification_token( return CreateEmailVerificationTokenEmailAlreadyVerifiedError() async def verify_email_using_token( - self, token: str, tenant_id: str, user_context: Dict[str, Any] + self, + token: str, + tenant_id: str, + attempt_account_linking: bool, + user_context: Dict[str, Any], ) -> Union[VerifyEmailUsingTokenOkResult, VerifyEmailUsingTokenInvalidTokenError]: data = {"method": "token", "token": token} response = await self.querier.send_post_request( @@ -65,25 +89,55 @@ async def verify_email_using_token( data, user_context, ) - if "status" in response and response["status"] == "OK": + if response["status"] == "OK": + recipe_user_id = RecipeUserId(response["userId"]) + if attempt_account_linking: + updated_user = await get_user( + recipe_user_id.get_as_string(), user_context + ) + + if updated_user: + # Check if the verified email is currently associated with the user ID + email_info = await self.get_email_for_recipe_user_id( + updated_user, recipe_user_id, user_context + ) + if ( + isinstance(email_info, GetEmailForUserIdOkResult) + and email_info.email == response["email"] + ): + from ..accountlinking.recipe import AccountLinkingRecipe + + account_linking = AccountLinkingRecipe.get_instance() + await account_linking.try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=updated_user, + session=None, + user_context=user_context, + ) + return VerifyEmailUsingTokenOkResult( - User(response["userId"], response["email"]) + EmailVerificationUser(recipe_user_id, response["email"]) ) - return VerifyEmailUsingTokenInvalidTokenError() + else: + return VerifyEmailUsingTokenInvalidTokenError() async def is_email_verified( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> bool: - params = {"userId": user_id, "email": email} + params = {"userId": recipe_user_id.get_as_string(), "email": email} response = await self.querier.send_get_request( NormalisedURLPath("/recipe/user/email/verify"), params, user_context ) return response["isVerified"] async def revoke_email_verification_tokens( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> RevokeEmailVerificationTokensOkResult: - data = {"userId": user_id, "email": email} + data = {"userId": recipe_user_id.get_as_string(), "email": email} await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/user/email/verify/token/remove"), data, @@ -92,9 +146,9 @@ async def revoke_email_verification_tokens( return RevokeEmailVerificationTokensOkResult() async def unverify_email( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> UnverifyEmailOkResult: - data = {"userId": user_id, "email": email} + data = {"userId": recipe_user_id.get_as_string(), "email": email} await self.querier.send_post_request( NormalisedURLPath("/recipe/user/email/verify/remove"), data, user_context ) diff --git a/supertokens_python/recipe/emailverification/syncio/__init__.py b/supertokens_python/recipe/emailverification/syncio/__init__.py index 9621cbec9..914f86498 100644 --- a/supertokens_python/recipe/emailverification/syncio/__init__.py +++ b/supertokens_python/recipe/emailverification/syncio/__init__.py @@ -16,11 +16,12 @@ from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.recipe.emailverification.types import EmailTemplateVars +from supertokens_python.types import RecipeUserId def create_email_verification_token( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): @@ -29,35 +30,40 @@ def create_email_verification_token( ) return sync( - create_email_verification_token(tenant_id, user_id, email, user_context) + create_email_verification_token(tenant_id, recipe_user_id, email, user_context) ) def verify_email_using_token( tenant_id: str, token: str, + attempt_account_linking: bool = True, user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import ( verify_email_using_token, ) - return sync(verify_email_using_token(tenant_id, token, user_context)) + return sync( + verify_email_using_token( + tenant_id, token, attempt_account_linking, user_context + ) + ) def is_email_verified( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import is_email_verified - return sync(is_email_verified(user_id, email, user_context)) + return sync(is_email_verified(recipe_user_id, email, user_context)) def revoke_email_verification_tokens( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, ): @@ -66,18 +72,18 @@ def revoke_email_verification_tokens( ) return sync( - revoke_email_verification_tokens(tenant_id, user_id, email, user_context) + revoke_email_verification_tokens(tenant_id, recipe_user_id, email, user_context) ) def unverify_email( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import unverify_email - return sync(unverify_email(user_id, email, user_context)) + return sync(unverify_email(recipe_user_id, email, user_context)) def send_email( @@ -91,7 +97,7 @@ def send_email( def create_email_verification_link( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ): @@ -99,12 +105,15 @@ def create_email_verification_link( create_email_verification_link, ) - return sync(create_email_verification_link(tenant_id, user_id, email, user_context)) + return sync( + create_email_verification_link(tenant_id, recipe_user_id, email, user_context) + ) def send_email_verification_email( tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ): @@ -112,4 +121,8 @@ def send_email_verification_email( send_email_verification_email, ) - return sync(send_email_verification_email(tenant_id, user_id, email, user_context)) + return sync( + send_email_verification_email( + tenant_id, user_id, recipe_user_id, email, user_context + ) + ) diff --git a/supertokens_python/recipe/emailverification/types.py b/supertokens_python/recipe/emailverification/types.py index c32b46ddf..118c7e70b 100644 --- a/supertokens_python/recipe/emailverification/types.py +++ b/supertokens_python/recipe/emailverification/types.py @@ -20,17 +20,19 @@ SMTPServiceInterface, EmailDeliveryInterface, ) +from supertokens_python.types import RecipeUserId -class User: - def __init__(self, user_id: str, email: str): - self.user_id = user_id +class EmailVerificationUser: + def __init__(self, recipe_user_id: RecipeUserId, email: str): + self.recipe_user_id = recipe_user_id self.email = email class VerificationEmailTemplateVarsUser: - def __init__(self, user_id: str, email: str): - self.id = user_id + def __init__(self, _id: str, recipe_user_id: RecipeUserId, email: str): + self.id = _id + self.recipe_user_id = recipe_user_id self.email = email diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 7dcebb515..71d207e8c 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -55,20 +55,20 @@ def __init__( get_email_delivery_config: Callable[ [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] ], - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction], + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction], override: OverrideConfig, ): self.mode = mode self.override = override self.get_email_delivery_config = get_email_delivery_config - self.get_email_for_user_id = get_email_for_user_id + self.get_email_for_recipe_user_id = get_email_for_recipe_user_id def validate_and_normalise_user_input( app_info: AppInfo, mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ) -> EmailVerificationConfig: if mode not in ["REQUIRED", "OPTIONAL"]: @@ -98,7 +98,7 @@ def get_email_delivery_config() -> ( return EmailVerificationConfig( mode, get_email_delivery_config, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 0a6f77764..0a83110e5 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -42,7 +42,7 @@ ContactPhoneOnlyConfig, ) from supertokens_python.recipe.session.asyncio import create_new_session -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import GeneralErrorResponse, RecipeUserId from ...emailverification import EmailVerificationRecipe from ...emailverification.interfaces import CreateEmailVerificationTokenOkResult @@ -289,12 +289,12 @@ async def consume_code_post( ev_instance = EmailVerificationRecipe.get_instance_optional() if ev_instance is not None: token_response = await ev_instance.recipe_implementation.create_email_verification_token( - user.user_id, user.email, tenant_id, user_context + RecipeUserId(user.user_id), user.email, tenant_id, user_context ) if isinstance(token_response, CreateEmailVerificationTokenOkResult): await ev_instance.recipe_implementation.verify_email_using_token( - token_response.token, tenant_id, user_context + token_response.token, tenant_id, True, user_context ) session = await create_new_session( diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index 0367710f3..93b7ef3dd 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -56,7 +56,6 @@ OverrideConfig, validate_and_normalise_user_input, ) -from ..emailverification import EmailVerificationRecipe from ..emailverification.interfaces import ( GetEmailForUserIdOkResult, EmailDoesNotExistError, @@ -142,9 +141,7 @@ def __init__( ) def callback(): - ev_recipe = EmailVerificationRecipe.get_instance_optional() - if ev_recipe: - ev_recipe.add_get_email_for_user_id_func(self.get_email_for_user_id) + pass PostSTInitCallbacks.add_post_init_callback(callback) diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index a50422885..a765b2d90 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -28,7 +28,7 @@ SessionInformationResult, ) from supertokens_python.recipe.session.recipe import SessionRecipe -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import FRAMEWORKS, resolve from ...jwt.interfaces import ( @@ -112,7 +112,9 @@ async def create_new_session_without_request_response( del final_access_token_payload[prop] for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, tenant_id, user_context) + update = await claim.build( + user_id, RecipeUserId(user_id), tenant_id, user_context + ) final_access_token_payload = {**final_access_token_payload, **update} return await SessionRecipe.get_instance().recipe_implementation.create_new_session( diff --git a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py index aec6c6f71..95781b406 100644 --- a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py @@ -13,7 +13,7 @@ # under the License. from typing import Any, Callable, Dict, Optional -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from .primitive_claim import PrimitiveClaim, PrimitiveClaimValidators @@ -31,7 +31,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any]], MaybeAwaitable[Optional[bool]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py index b3a8ad08d..e70d2663b 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import get_timestamp_ms from ..interfaces import ( @@ -132,7 +132,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any]], MaybeAwaitable[Optional[Primitive]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 63bce71c6..61134232c 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -632,7 +632,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any]], MaybeAwaitable[Optional[_T]], ], ) -> None: @@ -677,13 +677,16 @@ def get_value_from_payload( async def build( self, user_id: str, + recipe_user_id: RecipeUserId, tenant_id: str, user_context: Optional[Dict[str, Any]] = None, ) -> JSONObject: if user_context is None: user_context = {} - value = await resolve(self.fetch_value(user_id, tenant_id, user_context)) + value = await resolve( + self.fetch_value(user_id, recipe_user_id, tenant_id, user_context) + ) if value is None: return {} diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index ec02f8bbf..1c09a1fc0 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -20,7 +20,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.utils import resolve -from ...types import MaybeAwaitable +from ...types import MaybeAwaitable, RecipeUserId from . import session_functions from .access_token import validate_access_token_structure from .cookie_and_header import build_front_token @@ -128,6 +128,7 @@ async def validate_claims( value = await resolve( validator.claim.fetch_value( user_id, + RecipeUserId(user_id), access_token_payload.get("tId", DEFAULT_TENANT_ID), user_context, ) @@ -419,7 +420,10 @@ async def fetch_and_set_claim( return False access_token_payload_update = await claim.build( - session_info.user_id, session_info.tenant_id, user_context + session_info.user_id, + RecipeUserId(session_info.user_id), + session_info.tenant_id, + user_context, ) return await self.merge_into_access_token_payload( session_handle, access_token_payload_update, user_context diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index dba63b3c0..1a75b4183 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -163,9 +163,9 @@ def get_all_session_tokens_dangerously(self) -> GetSessionTokensDangerouslyDict: return { "accessToken": self.access_token, "accessAndFrontTokenUpdated": self.access_token_updated, - "refreshToken": None - if self.refresh_token is None - else self.refresh_token.token, + "refreshToken": ( + None if self.refresh_token is None else self.refresh_token.token + ), "frontToken": self.front_token, "antiCsrfToken": self.anti_csrf_token, } @@ -236,7 +236,10 @@ async def fetch_and_set_claim( user_context = {} update = await claim.build( - self.get_user_id(), self.get_tenant_id(), user_context + self.get_user_id(), + self.get_recipe_user_id(), + self.get_tenant_id(), + user_context, ) return await self.merge_into_access_token_payload(update, user_context) diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index 2fd8fcfa9..0e211627a 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -54,7 +54,7 @@ get_required_claim_validators, get_auth_mode_from_header, ) -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import ( FRAMEWORKS, get_rid_from_header, @@ -268,7 +268,9 @@ async def create_new_session_in_request( del final_access_token_payload[prop] for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, tenant_id, user_context) + update = await claim.build( + user_id, RecipeUserId(user_id), tenant_id, user_context + ) final_access_token_payload.update(update) log_debug_message("createNewSession: Access token payload built") diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index c9daa33f0..97974c883 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -36,7 +36,7 @@ from supertokens_python.recipe.thirdparty.interfaces import APIOptions from supertokens_python.recipe.thirdparty.provider import Provider -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import GeneralErrorResponse, RecipeUserId class APIImplementation(APIInterface): @@ -118,7 +118,7 @@ async def sign_in_up_post( if ev_instance is not None: token_response = await ev_instance.recipe_implementation.create_email_verification_token( tenant_id=tenant_id, - user_id=signinup_response.user.user_id, + recipe_user_id=RecipeUserId(signinup_response.user.user_id), email=signinup_response.user.email, user_context=user_context, ) @@ -127,6 +127,7 @@ async def sign_in_up_post( await ev_instance.recipe_implementation.verify_email_using_token( token=token_response.token, tenant_id=tenant_id, + attempt_account_linking=True, user_context=user_context, ) diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index 5c234c637..f9f50af0d 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -33,7 +33,6 @@ from .utils import SignInAndUpFeature, InputOverrideConfig from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from .api import ( @@ -81,10 +80,6 @@ def __init__( ) def callback(): - ev_recipe = EmailVerificationRecipe.get_instance_optional() - if ev_recipe: - ev_recipe.add_get_email_for_user_id_func(self.get_email_for_user_id) - mt_recipe = MultitenancyRecipe.get_instance_optional() if mt_recipe: mt_recipe.static_third_party_providers = self.providers diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index 74cc3fd7f..81c704a14 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -85,7 +85,7 @@ async def set_role_api(request: HttpRequest): @verify_session() async def unverify_email_api(request: HttpRequest): session_: SessionContainer = request.supertokens # type: ignore - await unverify_email(session_.get_user_id()) + await unverify_email(session_.get_recipe_user_id()) await session_.fetch_and_set_claim(EmailVerificationClaim) return JsonResponse({"status": "OK"}) @@ -139,7 +139,7 @@ def sync_set_role_api(request: HttpRequest): @verify_session() def sync_unverify_email_api(request: HttpRequest): session_: SessionContainer = request.supertokens # type: ignore - sync_unverify_email(session_.get_user_id()) + sync_unverify_email(session_.get_recipe_user_id()) session_.sync_fetch_and_set_claim(EmailVerificationClaim) return JsonResponse({"status": "OK"}) diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index 67fdcd407..fcaab51ba 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -780,7 +780,7 @@ async def get_token(): @app.get("/unverifyEmail") async def unverify_email_api(session_: SessionContainer = Depends(verify_session())): - await unverify_email(session_.get_user_id()) + await unverify_email(session_.get_recipe_user_id()) await session_.fetch_and_set_claim(EmailVerificationClaim) return JSONResponse({"status": "OK"}) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 07336dc8f..6e1618e30 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -796,7 +796,7 @@ def test_feature_flags(): @verify_session() def unverify_email_api(): session_: SessionContainer = g.supertokens # type: ignore - unverify_email(session_.get_user_id()) + unverify_email(session_.get_recipe_user_id()) session_.sync_fetch_and_set_claim(EmailVerificationClaim) return jsonify({"status": "OK"}) diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index bda081c8d..34c316032 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -19,6 +19,7 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -44,7 +45,7 @@ VerifyEmailUsingTokenInvalidTokenError, ) from supertokens_python.recipe.emailverification.types import ( - User as EVUser, + EmailVerificationUser as EVUser, ) from supertokens_python.recipe.emailverification.utils import OverrideConfig from supertokens_python.recipe.session import SessionContainer @@ -1372,13 +1373,13 @@ def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str: email = dict_response["user"]["email"] await send_email_verification_email( - "public", user_id, email, {"url": "localhost:3000"} + "public", user_id, RecipeUserId(user_id), email, {"url": "localhost:3000"} ) url = urlparse(email_verify_link) assert url.netloc == "localhost:3000" await send_email_verification_email( - "public", user_id, email, {"url": "localhost:3002"} + "public", user_id, RecipeUserId(user_id), email, {"url": "localhost:3002"} ) url = urlparse(email_verify_link) assert url.netloc == "localhost:3002" diff --git a/tests/input_validation/test_input_validation.py b/tests/input_validation/test_input_validation.py index 96beec0ea..aa01bca5f 100644 --- a/tests/input_validation/test_input_validation.py +++ b/tests/input_validation/test_input_validation.py @@ -16,6 +16,7 @@ from supertokens_python.recipe.emailverification.interfaces import ( GetEmailForUserIdOkResult, ) +from supertokens_python.types import RecipeUserId @pytest.mark.asyncio @@ -84,7 +85,7 @@ async def test_init_validation_emailpassword(): assert "override must be of type InputOverrideConfig or None" == str(ex.value) -async def get_email_for_user_id(_: str, __: Dict[str, Any]): +async def get_email_for_user_id(_: RecipeUserId, __: Dict[str, Any]): return GetEmailForUserIdOkResult("foo@example.com") @@ -120,7 +121,7 @@ async def test_init_validation_emailverification(): recipe_list=[ emailverification.init( mode="OPTIONAL", - get_email_for_user_id=get_email_for_user_id, + get_email_for_recipe_user_id=get_email_for_user_id, override="override", # type: ignore ) ], diff --git a/tests/passwordless/test_emaildelivery.py b/tests/passwordless/test_emaildelivery.py index ea8fa26ec..41d9418b9 100644 --- a/tests/passwordless/test_emaildelivery.py +++ b/tests/passwordless/test_emaildelivery.py @@ -20,6 +20,7 @@ import respx from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -183,7 +184,7 @@ async def send_email_override( pless_response = await signinup("public", "test@example.com", None, {}) create_token = await create_email_verification_token( - "public", pless_response.user.user_id + "public", RecipeUserId(pless_response.user.user_id) ) assert isinstance(create_token, CreateEmailVerificationTokenOkResult) diff --git a/tests/sessions/claims/test_assert_claims.py b/tests/sessions/claims/test_assert_claims.py index f1f303884..bb7cde83f 100644 --- a/tests/sessions/claims/test_assert_claims.py +++ b/tests/sessions/claims/test_assert_claims.py @@ -120,7 +120,7 @@ async def validate( def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]): return False - dummy_claim = PrimitiveClaim("st-claim", lambda _, __, ___: "Hello world") + dummy_claim = PrimitiveClaim("st-claim", lambda _, __, ___, ____: "Hello world") dummy_claim_validator = DummyClaimValidator(dummy_claim) diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index ac1469b56..47bd35b06 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -5,6 +5,7 @@ from pytest import fixture, mark from pytest_mock import MockerFixture from supertokens_python.recipe.session.claims import PrimitiveArrayClaim +from supertokens_python.types import RecipeUserId from supertokens_python.utils import get_timestamp_ms, resolve from tests.utils import AsyncMock from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID @@ -58,28 +59,28 @@ def patch_get_timestamp_ms(pac_time_patch: Tuple[MockerFixture, int]): async def test_primitive_claim(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveArrayClaim("key", async_fetch_value) ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveArrayClaim("key", sync_fetch_value) user_id, ctx = "user_id", {} - await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) assert sync_fetch_value.call_count == 1 assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ 0 @@ -92,7 +93,7 @@ async def test_primitive_claim_fetch_value_none(): claim = PrimitiveArrayClaim("key", fetch_value_none) user_id, ctx = "user_id", {} - res = await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) assert res == {} @@ -120,7 +121,7 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) assert claim.get_last_refetch_time(payload) == timestamp @@ -142,7 +143,7 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.includes(excluded_item).validate(payload, {}) assert res.is_valid is False @@ -155,7 +156,7 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.includes(included_item).validate(payload, {}) assert res.is_valid is True @@ -163,7 +164,7 @@ async def test_validator_should_validate_matching_payload(): async def test_should_not_validate_old_values(patch_get_timestamp_ms: MagicMock): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -181,7 +182,7 @@ async def test_should_validate_old_values_if_max_age_is_none_and_default_is_inf( patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -199,7 +200,7 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) assert ( await resolve( claim.validators.includes(excluded_item, 600).should_refetch(payload, {}) @@ -212,7 +213,7 @@ async def test_validator_should_refetch_if_value_is_old( patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -229,7 +230,7 @@ async def test_validator_should_not_refetch_if_max_age_is_none_and_default_is_in patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -246,7 +247,7 @@ async def test_validator_should_validate_values_with_default_max_age( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 10 MINS: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -259,7 +260,7 @@ async def test_validator_should_not_refetch_if_max_age_overrides_to_inf( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -292,7 +293,7 @@ async def test_validator_excludes_should_not_validate_empty_payload(): async def test_validator_excludes_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.excludes(included_item).validate(payload, {}) assert res.is_valid is False @@ -305,7 +306,7 @@ async def test_validator_excludes_should_not_validate_mismatching_payload(): async def test_validator_excludes_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.excludes(excluded_item).validate(payload, {}) assert res.is_valid is True @@ -328,7 +329,7 @@ async def test_validator_includes_all_should_not_validate_empty_payload(): async def test_validator_includes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.includes_all(excluded_item).validate(payload, {}) assert res.is_valid is False @@ -341,7 +342,7 @@ async def test_validator_includes_all_should_not_validate_mismatching_payload(): async def test_validator_includes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.includes_all(included_item).validate(payload, {}) assert res.is_valid is True @@ -364,7 +365,7 @@ async def test_validator_excludes_all_should_not_validate_empty_payload(): async def test_validator_excludes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.excludes_all(included_item).validate(payload, {}) assert res.is_valid is False @@ -377,7 +378,7 @@ async def test_validator_excludes_all_should_not_validate_mismatching_payload(): async def test_validator_excludes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.excludes_all(excluded_item).validate(payload, {}) assert res.is_valid is True @@ -387,7 +388,7 @@ async def test_validator_should_not_validate_older_values_with_5min_default_max_ patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 10 MINS: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 2ee8d98ba..515716e06 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -2,6 +2,7 @@ from pytest import mark from supertokens_python.recipe.session.claims import PrimitiveClaim +from supertokens_python.types import RecipeUserId from supertokens_python.utils import resolve from tests.utils import AsyncMock @@ -25,28 +26,28 @@ def teardown_function(_): async def test_primitive_claim(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveClaim("key", async_fetch_value) ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveClaim("key", sync_fetch_value) ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveClaim("key", sync_fetch_value) user_id, ctx = "user_id", {} - await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) assert sync_fetch_value.call_count == 1 assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ 0 @@ -59,7 +60,7 @@ async def test_primitive_claim_fetch_value_none(): claim = PrimitiveClaim("key", fetch_value_none) user_id, ctx = "user_id", {} - res = await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) assert res == {} @@ -89,7 +90,7 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) assert claim.get_last_refetch_time(payload) == timestamp @@ -111,7 +112,7 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.has_value(val2).validate(payload, {}) assert res.is_valid is False @@ -124,7 +125,7 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.has_value(val).validate(payload, {}) assert res.is_valid is True @@ -132,7 +133,7 @@ async def test_validator_should_validate_matching_payload(): async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -150,7 +151,7 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) assert ( await resolve(claim.validators.has_value(val2).should_refetch(payload, {})) is False @@ -174,7 +175,7 @@ async def test_should_not_validate_empty_payload(): async def test_has_fresh_value_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.has_value(val2, 600).validate(payload, {}) assert res.is_valid is False assert res.reason == { @@ -186,7 +187,7 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload(): async def test_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) res = await claim.validators.has_value(val, 600).validate(payload, {}) assert res.is_valid is True @@ -196,7 +197,7 @@ async def test_should_not_validate_old_values_as_well( ): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -213,14 +214,14 @@ async def test_should_refetch_if_value_is_not_set(): async def test_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", "public") + payload = await claim.build("userId", RecipeUserId("userId"), "public") assert claim.validators.has_value(val2, 600).should_refetch(payload, {}) is False async def test_should_refetch_if_value_is_old(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", DEFAULT_TENANT_ID) + payload = await claim.build("userId", RecipeUserId("userId"), DEFAULT_TENANT_ID) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -232,7 +233,7 @@ async def test_should_not_validate_old_values_as_well_with_default_max_age_provi patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -250,7 +251,7 @@ async def test_should_refetch_if_value_is_old_with_default_max_age_provided( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("userId", "public") + payload = await claim.build("userId", RecipeUserId("userId"), "public") # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index fd6ae7d65..529cd05a8 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -6,6 +6,7 @@ SessionClaim, ) from supertokens_python.recipe.session.interfaces import RecipeInterface +from supertokens_python.types import RecipeUserId from tests.utils import st_init_common_args TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___: True) # type: ignore @@ -29,7 +30,9 @@ async def new_create_new_session( tenant_id: str, user_context: Dict[str, Any], ): - payload_update = await claim.build(user_id, tenant_id, user_context) + payload_update = await claim.build( + user_id, RecipeUserId(user_id), tenant_id, user_context + ) if access_token_payload is None: access_token_payload = {} access_token_payload = { From e9ab8618666e8fbc673da5cd77b06c38b8cb3900 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 9 Aug 2024 15:45:22 +0530 Subject: [PATCH 014/126] makes changes for session recipe --- CHANGELOG.md | 2 + .../api/userdetails/user_sessions_get.py | 2 +- .../emailpassword/api/implementation.py | 6 +- .../recipe/emailverification/recipe.py | 4 +- .../recipe/passwordless/api/implementation.py | 2 +- .../recipe/session/access_token.py | 47 ++++++++-- .../recipe/session/api/implementation.py | 5 +- .../recipe/session/api/signout.py | 2 + .../recipe/session/asyncio/__init__.py | 85 +++++++------------ .../recipe/session/exceptions.py | 10 ++- .../recipe/session/interfaces.py | 27 +++--- supertokens_python/recipe/session/recipe.py | 2 +- .../recipe/session/recipe_implementation.py | 45 +++++----- .../recipe/session/session_class.py | 7 +- .../recipe/session/session_functions.py | 52 ++++++++++-- .../session/session_request_functions.py | 6 +- .../recipe/session/syncio/__init__.py | 57 +++++-------- supertokens_python/recipe/session/utils.py | 18 ++-- .../recipe/thirdparty/api/implementation.py | 2 +- tests/Django/test_django.py | 5 +- tests/Fastapi/test_fastapi.py | 7 +- tests/Flask/test_flask.py | 11 ++- tests/auth-react/django3x/mysite/utils.py | 2 +- tests/auth-react/fastapi-server/app.py | 2 +- tests/auth-react/flask-server/app.py | 2 +- tests/emailpassword/test_emailexists.py | 3 +- tests/emailpassword/test_emailverify.py | 2 +- tests/emailpassword/test_passwordreset.py | 3 +- tests/emailpassword/test_signin.py | 3 +- .../django2x/polls/views.py | 3 + .../django3x/polls/views.py | 3 + .../drf_async/mysite/settings.py | 4 + .../drf_async/polls/views.py | 3 + .../drf_sync/mysite/settings.py | 4 + .../drf_sync/polls/views.py | 3 + .../frontendIntegration/fastapi-server/app.py | 3 + tests/frontendIntegration/flask-server/app.py | 5 +- tests/jwt/test_get_JWKS.py | 3 +- tests/sessions/claims/test_assert_claims.py | 7 +- .../claims/test_create_new_session.py | 7 +- .../claims/test_fetch_and_set_claim.py | 3 + tests/sessions/claims/test_get_claim_value.py | 7 +- tests/sessions/claims/test_remove_claim.py | 10 ++- tests/sessions/claims/test_set_claim_value.py | 6 +- ...test_validate_claims_for_session_handle.py | 3 +- tests/sessions/claims/test_verify_session.py | 13 ++- tests/sessions/claims/utils.py | 2 + tests/sessions/test_access_token_version.py | 13 ++- tests/sessions/test_auth_mode.py | 5 +- tests/sessions/test_jwks.py | 35 ++++++-- tests/sessions/test_session_error_handlers.py | 9 +- .../test_use_dynamic_signing_key_switching.py | 5 +- tests/test-server/session.py | 2 + tests/test_config.py | 9 +- tests/test_session.py | 21 +++-- tests/test_user_context.py | 7 ++ tests/thirdparty/test_emaildelivery.py | 11 +-- tests/userroles/test_claims.py | 15 ++-- 58 files changed, 404 insertions(+), 238 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d6605b1c..23944210e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `supertokens_python.recipe.emailverification.types.User` has been renamed to `supertokens_python.recipe.emailverification.types.EmailVerificationUser` - The user object has been changed to be a global one, containing information about all emails, phone numbers, third party info and login methods associated with that user. - Type of `get_email_for_user_id` in `emailverification.init` has changed +- Session recipe's error handlers take an extra param of recipe_user_id as well +- Session recipe, removes `validate_claims_in_jwt_payload` that is exposed to the user. ## [0.24.0] - 2024-07-31 diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py index 40b4e82b9..5dc815efd 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py @@ -29,7 +29,7 @@ async def handle_sessions_get( # Passing tenant id as None sets fetch_across_all_tenants to True # which is what we want here. session_handles = await get_all_session_handles_for_user( - user_id, None, user_context + user_id, True, None, user_context ) sessions: List[Optional[SessionInfo]] = [None for _ in session_handles] diff --git a/supertokens_python/recipe/emailpassword/api/implementation.py b/supertokens_python/recipe/emailpassword/api/implementation.py index 54f90fd04..0e0953db3 100644 --- a/supertokens_python/recipe/emailpassword/api/implementation.py +++ b/supertokens_python/recipe/emailpassword/api/implementation.py @@ -47,7 +47,7 @@ if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import APIOptions -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import GeneralErrorResponse, RecipeUserId class APIImplementation(APIInterface): @@ -178,7 +178,7 @@ async def sign_in_post( session = await create_new_session( tenant_id=tenant_id, request=api_options.request, - user_id=user.user_id, + recipe_user_id=RecipeUserId(user.user_id), access_token_payload={}, session_data_in_database={}, user_context=user_context, @@ -219,7 +219,7 @@ async def sign_up_post( session = await create_new_session( tenant_id=tenant_id, request=api_options.request, - user_id=user.user_id, + recipe_user_id=RecipeUserId(user.user_id), access_token_payload={}, session_data_in_database={}, user_context=user_context, diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 670adb84f..d257026c4 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -413,7 +413,7 @@ async def update_session_if_required_post_email_verification( # We do not really need to do this, but we do it anyway.. no harm. await revoke_all_sessions_for_user( recipe_user_id_whose_email_got_verified.get_as_string(), - # False, + False, session.get_tenant_id(), user_context, ) @@ -422,7 +422,7 @@ async def update_session_if_required_post_email_verification( return await create_new_session( req, session.get_tenant_id(), - session.get_recipe_user_id(user_context).get_as_string(), + session.get_recipe_user_id(user_context), {}, {}, user_context, diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 0a83110e5..840da912a 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -300,7 +300,7 @@ async def consume_code_post( session = await create_new_session( request=api_options.request, tenant_id=tenant_id, - user_id=user.user_id, + recipe_user_id=RecipeUserId(user.user_id), access_token_payload={}, session_data_in_database={}, user_context=user_context, diff --git a/supertokens_python/recipe/session/access_token.py b/supertokens_python/recipe/session/access_token.py index ac9ea51e7..ce4eec245 100644 --- a/supertokens_python/recipe/session/access_token.py +++ b/supertokens_python/recipe/session/access_token.py @@ -101,6 +101,7 @@ def get_info_from_access_token( user_data = payload session_handle = sanitize_string(payload.get("sessionHandle")) + recipe_user_id = sanitize_string(payload.get("recipeUserId", user_id)) refresh_token_hash_1 = sanitize_string(payload.get("refreshTokenHash1")) parent_refresh_token_hash_1 = sanitize_string( payload.get("parentRefreshTokenHash1") @@ -129,6 +130,7 @@ def get_info_from_access_token( "expiryTime": expiry_time, "timeCreated": time_created, "tenantId": tenant_id, + "recipeUserId": recipe_user_id, } except Exception as e: log_debug_message( @@ -139,7 +141,40 @@ def get_info_from_access_token( def validate_access_token_structure(payload: Dict[str, Any], version: int) -> None: - if version >= 3: + if version >= 5: + if ( + not isinstance(payload.get("sub"), str) + or not isinstance(payload.get("exp"), (int, float)) + or not isinstance(payload.get("iat"), (int, float)) + or not isinstance(payload.get("sessionHandle"), str) + or not isinstance(payload.get("refreshTokenHash1"), str) + or not isinstance(payload.get("rsub"), str) + ): + log_debug_message( + "validateAccessTokenStructure: Access token is using version >= 5" + ) + # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. + raise Exception( + "Access token does not contain all the information. Maybe the structure has changed?" + ) + elif version >= 4: + if ( + not isinstance(payload.get("sub"), str) + or not isinstance(payload.get("exp"), (int, float)) + or not isinstance(payload.get("iat"), (int, float)) + or not isinstance(payload.get("sessionHandle"), str) + or not isinstance(payload.get("refreshTokenHash1"), str) + ): + log_debug_message( + "validateAccessTokenStructure: Access token is using version >= 4" + ) + # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. + raise Exception( + "Access token does not contain all the information. Maybe the structure has changed?" + ) + elif version >= 3: if ( not isinstance(payload.get("sub"), str) or not isinstance(payload.get("exp"), (int, float)) @@ -151,6 +186,7 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No "validateAccessTokenStructure: Access token is using version >= 3" ) # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. raise Exception( "Access token does not contain all the information. Maybe the structure has changed?" ) @@ -160,18 +196,19 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No raise Exception( "Access token does not contain all the information. Maybe the structure has changed?" ) - elif ( not isinstance(payload.get("sessionHandle"), str) - or payload.get("userData") is None + or not isinstance(payload.get("userId"), str) or not isinstance(payload.get("refreshTokenHash1"), str) - or not isinstance(payload.get("expiryTime"), (float, int)) - or not isinstance(payload.get("timeCreated"), (float, int)) + or payload.get("userData") is None + or not isinstance(payload.get("expiryTime"), (int, float)) + or not isinstance(payload.get("timeCreated"), (int, float)) ): log_debug_message( "validateAccessTokenStructure: Access token is using version < 3" ) # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. raise Exception( "Access token does not contain all the information. Maybe the structure has changed?" ) diff --git a/supertokens_python/recipe/session/api/implementation.py b/supertokens_python/recipe/session/api/implementation.py index 1521243e8..2229f6ead 100644 --- a/supertokens_python/recipe/session/api/implementation.py +++ b/supertokens_python/recipe/session/api/implementation.py @@ -49,12 +49,11 @@ async def refresh_post( async def signout_post( self, - session: Optional[SessionContainer], + session: SessionContainer, api_options: APIOptions, user_context: Dict[str, Any], ) -> SignOutOkayResponse: - if session is not None: - await session.revoke_session(user_context) + await session.revoke_session(user_context) return SignOutOkayResponse() async def verify_session( diff --git a/supertokens_python/recipe/session/api/signout.py b/supertokens_python/recipe/session/api/signout.py index 5fcf69c07..9c0c68be0 100644 --- a/supertokens_python/recipe/session/api/signout.py +++ b/supertokens_python/recipe/session/api/signout.py @@ -48,6 +48,8 @@ async def handle_signout_api( user_context=user_context, ) + assert session is not None + response = await api_implementation.signout_post(session, api_options, user_context) if api_options.response is None: raise Exception("Should never come here") diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index a765b2d90..346d111ef 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -20,7 +20,6 @@ from supertokens_python.recipe.session.interfaces import ( ClaimsValidationResult, GetClaimValueOkResult, - JSONObject, SessionClaim, SessionClaimValidator, SessionContainer, @@ -45,6 +44,7 @@ from ..utils import get_required_claim_validators from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from supertokens_python.asyncio import get_user _T = TypeVar("_T") @@ -52,7 +52,7 @@ async def create_new_session( request: Any, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -68,12 +68,18 @@ async def create_new_session( config = recipe_instance.config app_info = recipe_instance.app_info + user = await get_user(recipe_user_id.get_as_string(), user_context) + user_id = recipe_user_id.get_as_string() + if user is not None: + user_id = user.id + return await create_new_session_in_request( request, user_context, recipe_instance, access_token_payload, user_id, + recipe_user_id, config, app_info, session_data_in_database, @@ -83,7 +89,7 @@ async def create_new_session( async def create_new_session_without_request_response( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, disable_anti_csrf: bool = False, @@ -111,14 +117,18 @@ async def create_new_session_without_request_response( if prop in final_access_token_payload: del final_access_token_payload[prop] + user = await get_user(recipe_user_id.get_as_string(), user_context) + user_id = recipe_user_id.get_as_string() + if user is not None: + user_id = user.id + for claim in claims_added_by_other_recipes: - update = await claim.build( - user_id, RecipeUserId(user_id), tenant_id, user_context - ) + update = await claim.build(user_id, recipe_user_id, tenant_id, user_context) final_access_token_payload = {**final_access_token_payload, **update} return await SessionRecipe.get_instance().recipe_implementation.create_new_session( user_id, + recipe_user_id, final_access_token_payload, session_data_in_database, disable_anti_csrf, @@ -159,6 +169,7 @@ async def validate_claims_for_session_handle( recipe_impl.get_global_claim_validators( session_info.tenant_id, session_info.user_id, + session_info.recipe_user_id, claim_validators_added_by_other_recipes, user_context, ) @@ -175,6 +186,7 @@ async def validate_claims_for_session_handle( claim_validation_res = await recipe_impl.validate_claims( session_info.user_id, + session_info.recipe_user_id, session_info.custom_claims_in_access_token_payload, claim_validators, user_context, @@ -192,53 +204,6 @@ async def validate_claims_for_session_handle( return ClaimsValidationResult(claim_validation_res.invalid_claims) -async def validate_claims_in_jwt_payload( - tenant_id: str, - user_id: str, - jwt_payload: JSONObject, - override_global_claim_validators: Optional[ - Callable[ - [ - List[SessionClaimValidator], - str, - Dict[str, Any], - ], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ] = None, - user_context: Union[None, Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - - recipe_impl = SessionRecipe.get_instance().recipe_implementation - - claim_validators_added_by_other_recipes = ( - SessionRecipe.get_instance().get_claim_validators_added_by_other_recipes() - ) - global_claim_validators = await resolve( - recipe_impl.get_global_claim_validators( - tenant_id, - user_id, - claim_validators_added_by_other_recipes, - user_context, - ) - ) - - if override_global_claim_validators is not None: - claim_validators = await resolve( - override_global_claim_validators( - global_claim_validators, user_id, user_context - ) - ) - else: - claim_validators = global_claim_validators - - return await recipe_impl.validate_claims_in_jwt_payload( - user_id, jwt_payload, claim_validators, user_context - ) - - async def fetch_and_set_claim( session_handle: str, claim: SessionClaim[Any], @@ -437,25 +402,35 @@ async def revoke_session( async def revoke_all_sessions_for_user( user_id: str, + revoke_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: if user_context is None: user_context = {} return await SessionRecipe.get_instance().recipe_implementation.revoke_all_sessions_for_user( - user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context + user_id, + revoke_sessions_for_linked_accounts, + tenant_id or DEFAULT_TENANT_ID, + tenant_id is None, + user_context, ) async def get_all_session_handles_for_user( user_id: str, + fetch_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: if user_context is None: user_context = {} return await SessionRecipe.get_instance().recipe_implementation.get_all_session_handles_for_user( - user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context + user_id, + fetch_sessions_for_linked_accounts, + tenant_id or DEFAULT_TENANT_ID, + tenant_id is None, + user_context, ) diff --git a/supertokens_python/recipe/session/exceptions.py b/supertokens_python/recipe/session/exceptions.py index 9c0d1552e..637cf486c 100644 --- a/supertokens_python/recipe/session/exceptions.py +++ b/supertokens_python/recipe/session/exceptions.py @@ -16,13 +16,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union from supertokens_python.exceptions import SuperTokensError +from supertokens_python.types import RecipeUserId if TYPE_CHECKING: from .interfaces import ResponseMutator -def raise_token_theft_exception(user_id: str, session_handle: str) -> NoReturn: - raise TokenTheftError(user_id, session_handle) +def raise_token_theft_exception( + user_id: str, recipe_user_id: RecipeUserId, session_handle: str +) -> NoReturn: + raise TokenTheftError(user_id, recipe_user_id, session_handle) def raise_try_refresh_token_exception(ex: Union[str, Exception]) -> NoReturn: @@ -60,9 +63,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class TokenTheftError(SuperTokensSessionError): - def __init__(self, user_id: str, session_handle: str): + def __init__(self, user_id: str, recipe_user_id: RecipeUserId, session_handle: str): super().__init__("token theft detected") self.user_id = user_id + self.recipe_user_id = recipe_user_id self.session_handle = session_handle diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 61134232c..0f9b6f19c 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -50,6 +50,7 @@ def __init__( self, handle: str, user_id: str, + recipe_user_id: RecipeUserId, user_data_in_jwt: Dict[str, Any], tenant_id: str, ): @@ -57,6 +58,7 @@ def __init__( self.user_id = user_id self.user_data_in_jwt = user_data_in_jwt self.tenant_id = tenant_id + self.recipe_user_id = recipe_user_id class AccessTokenObj: @@ -77,6 +79,7 @@ def __init__( self, session_handle: str, user_id: str, + recipe_user_id: RecipeUserId, session_data_in_database: Dict[str, Any], expiry: int, custom_claims_in_access_token_payload: Dict[str, Any], @@ -92,6 +95,7 @@ def __init__( ) self.time_created = time_created self.tenant_id = tenant_id + self.recipe_user_id = recipe_user_id class ReqResInfo: @@ -148,6 +152,7 @@ def __init__(self): async def create_new_session( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -161,6 +166,7 @@ def get_global_claim_validators( self, tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> MaybeAwaitable[List[SessionClaimValidator]]: @@ -188,22 +194,13 @@ async def get_session( async def validate_claims( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Dict[str, Any], claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> ClaimsValidationResult: pass - @abstractmethod - async def validate_claims_in_jwt_payload( - self, - user_id: str, - jwt_payload: JSONObject, - claim_validators: List[SessionClaimValidator], - user_context: Dict[str, Any], - ) -> ClaimsValidationResult: - pass - @abstractmethod async def refresh_session( self, @@ -224,6 +221,7 @@ async def revoke_session( async def revoke_all_sessions_for_user( self, user_id: str, + revoke_sessions_for_linked_accounts: bool, tenant_id: str, revoke_across_all_tenants: bool, user_context: Dict[str, Any], @@ -234,6 +232,7 @@ async def revoke_all_sessions_for_user( async def get_all_session_handles_for_user( self, user_id: str, + fetch_sessions_for_linked_accounts: bool, tenant_id: str, fetch_across_all_tenants: bool, user_context: Dict[str, Any], @@ -359,7 +358,7 @@ async def refresh_post( @abstractmethod async def signout_post( self, - session: Optional[SessionContainer], + session: SessionContainer, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[SignOutOkayResponse, GeneralErrorResponse]: @@ -404,6 +403,7 @@ def __init__( anti_csrf_token: Optional[str], session_handle: str, user_id: str, + recipe_user_id: RecipeUserId, user_data_in_access_token: Optional[Dict[str, Any]], req_res_info: Optional[ReqResInfo], access_token_updated: bool, @@ -421,10 +421,7 @@ def __init__( self.req_res_info: Optional[ReqResInfo] = req_res_info self.access_token_updated = access_token_updated self.tenant_id = tenant_id - self.recipe_user_id = RecipeUserId( - user_id - ) # TODO: change me to be based on input arg. - + self.recipe_user_id = recipe_user_id self.response_mutators: List[ResponseMutator] = [] @abstractmethod diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 6fa02621c..7e4eb4799 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -263,7 +263,7 @@ async def handle_error( response, self, request, user_context ) return await self.config.error_handlers.on_token_theft_detected( - request, err.session_handle, err.user_id, response + request, err.session_handle, err.user_id, err.recipe_user_id, response ) if isinstance(err, InvalidClaimsError): log_debug_message("errorHandler: returning INVALID_CLAIMS") diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index 1c09a1fc0..225622c75 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -29,7 +29,6 @@ AccessTokenObj, ClaimsValidationResult, GetClaimValueOkResult, - JSONObject, RecipeInterface, RegenerateAccessTokenOkResult, SessionClaim, @@ -62,6 +61,7 @@ def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo): async def create_new_session( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -73,7 +73,7 @@ async def create_new_session( result = await session_functions.create_new_session( self, tenant_id, - user_id, + recipe_user_id, disable_anti_csrf is True, access_token_payload, session_data_in_database, @@ -96,6 +96,7 @@ async def create_new_session( result.antiCsrfToken, result.session.handle, result.session.userId, + recipe_user_id, payload, None, True, @@ -107,6 +108,7 @@ async def create_new_session( async def validate_claims( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Dict[str, Any], claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any], @@ -128,7 +130,7 @@ async def validate_claims( value = await resolve( validator.claim.fetch_value( user_id, - RecipeUserId(user_id), + recipe_user_id, access_token_payload.get("tId", DEFAULT_TENANT_ID), user_context, ) @@ -152,21 +154,6 @@ async def validate_claims( return ClaimsValidationResult(invalid_claims, access_token_payload_update) - async def validate_claims_in_jwt_payload( - self, - user_id: str, - jwt_payload: JSONObject, - claim_validators: List[SessionClaimValidator], - user_context: Dict[str, Any], - ) -> ClaimsValidationResult: - invalid_claims = await validate_claims_in_payload( - claim_validators, - jwt_payload, - user_context, - ) - - return ClaimsValidationResult(invalid_claims) - async def get_session( self, access_token: Optional[str], @@ -267,6 +254,7 @@ async def get_session( anti_csrf_token, response.session.handle, response.session.userId, + response.session.recipe_user_id, payload, None, access_token_updated, @@ -321,6 +309,7 @@ async def refresh_session( response.antiCsrfToken, response.session.handle, response.session.userId, + response.session.recipe_user_id, user_data_in_access_token=payload, req_res_info=None, access_token_updated=True, @@ -339,23 +328,35 @@ async def revoke_session( async def revoke_all_sessions_for_user( self, user_id: str, + revoke_sessions_for_linked_accounts: bool, tenant_id: Optional[str], revoke_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: return await session_functions.revoke_all_sessions_for_user( - self, user_id, tenant_id, revoke_across_all_tenants, user_context + self, + user_id, + revoke_sessions_for_linked_accounts, + tenant_id, + revoke_across_all_tenants, + user_context, ) async def get_all_session_handles_for_user( self, user_id: str, + fetch_sessions_for_linked_accounts: bool, tenant_id: Optional[str], fetch_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: return await session_functions.get_all_session_handles_for_user( - self, user_id, tenant_id, fetch_across_all_tenants, user_context + self, + user_id, + fetch_sessions_for_linked_accounts, + tenant_id, + fetch_across_all_tenants, + user_context, ) async def revoke_multiple_sessions( @@ -421,7 +422,7 @@ async def fetch_and_set_claim( access_token_payload_update = await claim.build( session_info.user_id, - RecipeUserId(session_info.user_id), + session_info.recipe_user_id, session_info.tenant_id, user_context, ) @@ -461,6 +462,7 @@ def get_global_claim_validators( self, tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> MaybeAwaitable[List[SessionClaimValidator]]: @@ -502,6 +504,7 @@ async def regenerate_access_token( session = SessionObj( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], response["session"]["tenantId"], ) diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index 1a75b4183..cf4ea2b45 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -210,6 +210,7 @@ async def assert_claims( validate_claim_res = await self.recipe_implementation.validate_claims( self.get_user_id(user_context), + self.get_recipe_user_id(user_context), self.get_access_token_payload(user_context), claim_validators, user_context, @@ -236,9 +237,9 @@ async def fetch_and_set_claim( user_context = {} update = await claim.build( - self.get_user_id(), - self.get_recipe_user_id(), - self.get_tenant_id(), + self.get_user_id(user_context=user_context), + self.get_recipe_user_id(user_context=user_context), + self.get_tenant_id(user_context=user_context), user_context, ) return await self.merge_into_access_token_payload(update, user_context) diff --git a/supertokens_python/recipe/session/session_functions.py b/supertokens_python/recipe/session/session_functions.py index 5037ad45f..a09200633 100644 --- a/supertokens_python/recipe/session/session_functions.py +++ b/supertokens_python/recipe/session/session_functions.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional from supertokens_python.recipe.session.interfaces import SessionInformationResult +from supertokens_python.types import RecipeUserId from .access_token import get_info_from_access_token from .jwt import ParsedJWTInfo @@ -40,9 +41,17 @@ class CreateOrRefreshAPIResponseSession: - def __init__(self, handle: str, userId: str, userDataInJWT: Any, tenant_id: str): + def __init__( + self, + handle: str, + userId: str, + recipe_user_id: RecipeUserId, + userDataInJWT: Any, + tenant_id: str, + ): self.handle = handle self.userId = userId + self.recipe_user_id = recipe_user_id self.userDataInJWT = userDataInJWT self.tenant_id = tenant_id @@ -66,12 +75,14 @@ def __init__( self, handle: str, userId: str, + recipe_user_id: RecipeUserId, userDataInJWT: Dict[str, Any], expiryTime: int, tenant_id: str, ) -> None: self.handle = handle self.userId = userId + self.recipe_user_id = recipe_user_id self.userDataInJWT = userDataInJWT self.expiryTime = expiryTime self.tenant_id = tenant_id @@ -97,7 +108,7 @@ def __init__( async def create_new_session( recipe_implementation: RecipeImplementation, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, disable_anti_csrf: bool, access_token_payload: Union[None, Dict[str, Any]], session_data_in_database: Union[None, Dict[str, Any]], @@ -115,7 +126,7 @@ async def create_new_session( response = await recipe_implementation.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/session"), { - "userId": user_id, + "userId": recipe_user_id.get_as_string(), "userDataInJWT": access_token_payload, "userDataInDatabase": session_data_in_database, "useDynamicSigningKey": recipe_implementation.config.use_dynamic_access_token_signing_key, @@ -128,6 +139,7 @@ async def create_new_session( CreateOrRefreshAPIResponseSession( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], response["session"]["tenantId"], ), @@ -264,6 +276,7 @@ async def get_session( GetSessionAPIResponseSession( access_token_info["sessionHandle"], access_token_info["userId"], + RecipeUserId(access_token_info["recipeUserId"]), access_token_info["userData"], access_token_info["expiryTime"], access_token_info["tenantId"], @@ -291,6 +304,7 @@ async def get_session( GetSessionAPIResponseSession( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], ( response.get("accessToken", {}).get( @@ -369,6 +383,7 @@ async def refresh_session( CreateOrRefreshAPIResponseSession( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], response["session"]["tenantId"], ), @@ -393,13 +408,16 @@ async def refresh_session( "refreshSession: Returning TOKEN_THEFT_DETECTED because of core response" ) raise_token_theft_exception( - response["session"]["userId"], response["session"]["handle"] + response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), + response["session"]["handle"], ) async def revoke_all_sessions_for_user( recipe_implementation: RecipeImplementation, user_id: str, + revoke_sessions_for_linked_accounts: bool, tenant_id: Optional[str], revoke_across_all_tenants: bool, user_context: Optional[Dict[str, Any]], @@ -410,13 +428,21 @@ async def revoke_all_sessions_for_user( if revoke_across_all_tenants: response = await recipe_implementation.querier.send_post_request( NormalisedURLPath("/recipe/session/remove"), - {"userId": user_id, "revokeAcrossAllTenants": revoke_across_all_tenants}, + { + "userId": user_id, + "revokeAcrossAllTenants": revoke_across_all_tenants, + "revokeSessionsForLinkedAccounts": revoke_sessions_for_linked_accounts, + }, user_context=user_context, ) else: response = await recipe_implementation.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/session/remove"), - {"userId": user_id, "revokeAcrossAllTenants": revoke_across_all_tenants}, + { + "userId": user_id, + "revokeAcrossAllTenants": revoke_across_all_tenants, + "revokeSessionsForLinkedAccounts": revoke_sessions_for_linked_accounts, + }, user_context=user_context, ) return response["sessionHandlesRevoked"] @@ -425,6 +451,7 @@ async def revoke_all_sessions_for_user( async def get_all_session_handles_for_user( recipe_implementation: RecipeImplementation, user_id: str, + fetch_sessions_for_linked_accounts: bool, tenant_id: Optional[str], fetch_across_all_tenants: bool, user_context: Optional[Dict[str, Any]], @@ -435,13 +462,21 @@ async def get_all_session_handles_for_user( if fetch_across_all_tenants: response = await recipe_implementation.querier.send_get_request( NormalisedURLPath("/recipe/session/user"), - {"userId": user_id, "fetchAcrossAllTenants": fetch_across_all_tenants}, + { + "userId": user_id, + "fetchAcrossAllTenants": fetch_across_all_tenants, + "fetchSessionsForAllLinkedAccounts": fetch_sessions_for_linked_accounts, + }, user_context=user_context, ) else: response = await recipe_implementation.querier.send_get_request( NormalisedURLPath(f"{tenant_id}/recipe/session/user"), - {"userId": user_id, "fetchAcrossAllTenants": fetch_across_all_tenants}, + { + "userId": user_id, + "fetchAcrossAllTenants": fetch_across_all_tenants, + "fetchSessionsForAllLinkedAccounts": fetch_sessions_for_linked_accounts, + }, user_context=user_context, ) return response["sessionHandles"] @@ -521,6 +556,7 @@ async def get_session_information( return SessionInformationResult( response["sessionHandle"], response["userId"], + RecipeUserId(response["recipeUserId"]), response["userDataInDatabase"], response["expiry"], response["userDataInJWT"], diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index 0e211627a..b1e297c2d 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -239,6 +239,7 @@ async def create_new_session_in_request( recipe_instance: SessionRecipe, access_token_payload: Dict[str, Any], user_id: str, + recipe_user_id: RecipeUserId, config: SessionConfig, app_info: AppInfo, session_data_in_database: Dict[str, Any], @@ -268,9 +269,7 @@ async def create_new_session_in_request( del final_access_token_payload[prop] for claim in claims_added_by_other_recipes: - update = await claim.build( - user_id, RecipeUserId(user_id), tenant_id, user_context - ) + update = await claim.build(user_id, recipe_user_id, tenant_id, user_context) final_access_token_payload.update(update) log_debug_message("createNewSession: Access token payload built") @@ -316,6 +315,7 @@ async def create_new_session_in_request( disable_anti_csrf = output_transfer_method == "header" session = await recipe_instance.recipe_implementation.create_new_session( user_id, + recipe_user_id, final_access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/supertokens_python/recipe/session/syncio/__init__.py b/supertokens_python/recipe/session/syncio/__init__.py index 71c87545e..50eec7a11 100644 --- a/supertokens_python/recipe/session/syncio/__init__.py +++ b/supertokens_python/recipe/session/syncio/__init__.py @@ -30,17 +30,17 @@ SessionInformationResult, SessionClaimValidator, SessionClaim, - JSONObject, ClaimsValidationResult, SessionDoesNotExistError, GetClaimValueOkResult, ) +from supertokens_python.recipe.session.recipe_implementation import RecipeUserId def create_new_session( request: Any, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -51,9 +51,9 @@ def create_new_session( return sync( async_create_new_session( - tenant_id=tenant_id, request=request, - user_id=user_id, + tenant_id=tenant_id, + recipe_user_id=recipe_user_id, access_token_payload=access_token_payload, session_data_in_database=session_data_in_database, user_context=user_context, @@ -63,7 +63,7 @@ def create_new_session( def create_new_session_without_request_response( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, disable_anti_csrf: bool = False, @@ -76,7 +76,7 @@ def create_new_session_without_request_response( return sync( async_create_new_session_without_request_response( tenant_id, - user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -187,6 +187,7 @@ def revoke_session( def revoke_all_sessions_for_user( user_id: str, + revoke_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: @@ -194,11 +195,19 @@ def revoke_all_sessions_for_user( revoke_all_sessions_for_user as async_revoke_all_sessions_for_user, ) - return sync(async_revoke_all_sessions_for_user(user_id, tenant_id, user_context)) + return sync( + async_revoke_all_sessions_for_user( + user_id, + revoke_sessions_for_linked_accounts, + tenant_id, + user_context, + ) + ) def get_all_session_handles_for_user( user_id: str, + fetch_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: @@ -207,7 +216,12 @@ def get_all_session_handles_for_user( ) return sync( - async_get_all_session_handles_for_user(user_id, tenant_id, user_context) + async_get_all_session_handles_for_user( + user_id, + fetch_sessions_for_linked_accounts, + tenant_id, + user_context, + ) ) @@ -365,30 +379,3 @@ def validate_claims_for_session_handle( session_handle, override_global_claim_validators, user_context ) ) - - -def validate_claims_in_jwt_payload( - tenant_id: str, - user_id: str, - jwt_payload: JSONObject, - override_global_claim_validators: Optional[ - Callable[ - [List[SessionClaimValidator], str, Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ] = None, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.session.asyncio import ( - validate_claims_in_jwt_payload as async_validate_claims_in_jwt_payload, - ) - - return sync( - async_validate_claims_in_jwt_payload( - tenant_id, - user_id, - jwt_payload, - override_global_claim_validators, - user_context, - ) - ) diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 5577147fa..96e5c43a4 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -30,7 +30,7 @@ send_non_200_response_with_message, ) -from ...types import MaybeAwaitable +from ...types import MaybeAwaitable, RecipeUserId from .constants import AUTH_MODE_HEADER_KEY, SESSION_REFRESH from .exceptions import ClaimValidationError @@ -100,7 +100,7 @@ class ErrorHandlers: def __init__( self, on_token_theft_detected: Callable[ - [BaseRequest, str, str, BaseResponse], + [BaseRequest, str, str, RecipeUserId, BaseResponse], Union[BaseResponse, Awaitable[BaseResponse]], ], on_try_refresh_token: Callable[ @@ -131,10 +131,13 @@ async def on_token_theft_detected( request: BaseRequest, session_handle: str, user_id: str, + recipe_user_id: RecipeUserId, response: BaseResponse, ) -> BaseResponse: return await resolve( - self.__on_token_theft_detected(request, session_handle, user_id, response) + self.__on_token_theft_detected( + request, session_handle, user_id, recipe_user_id, response + ) ) async def on_try_refresh_token( @@ -182,7 +185,7 @@ def __init__( on_token_theft_detected: Union[ None, Callable[ - [BaseRequest, str, str, BaseResponse], + [BaseRequest, str, str, RecipeUserId, BaseResponse], Union[BaseResponse, Awaitable[BaseResponse]], ], ] = None, @@ -261,7 +264,11 @@ async def default_try_refresh_token_callback( async def default_token_theft_detected_callback( - _: BaseRequest, session_handle: str, __: str, response: BaseResponse + _: BaseRequest, + session_handle: str, + __: str, + ___: RecipeUserId, + response: BaseResponse, ) -> BaseResponse: from .recipe import SessionRecipe @@ -576,6 +583,7 @@ async def get_required_claim_validators( SessionRecipe.get_instance().recipe_implementation.get_global_claim_validators( session.get_tenant_id(), session.get_user_id(), + session.get_recipe_user_id(), claim_validators_added_by_other_recipes, user_context, ) diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index 97974c883..6a49d423d 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -135,7 +135,7 @@ async def sign_in_up_post( session = await create_new_session( request=api_options.request, tenant_id=tenant_id, - user_id=user.user_id, + recipe_user_id=RecipeUserId(user.user_id), user_context=user_context, ) diff --git a/tests/Django/test_django.py b/tests/Django/test_django.py index 6305a1b2e..9e221abc6 100644 --- a/tests/Django/test_django.py +++ b/tests/Django/test_django.py @@ -41,6 +41,7 @@ from supertokens_python.recipe.session.framework.django.asyncio import verify_session import pytest +from supertokens_python.types import RecipeUserId from tests.utils import ( clean_st, reset, @@ -86,7 +87,7 @@ def get_cookies(response: HttpResponse) -> Dict[str, Any]: async def create_new_session_view(request: HttpRequest): - await create_new_session(request, "public", "user_id") + await create_new_session(request, "public", RecipeUserId("user_id")) return JsonResponse({"foo": "bar"}) @@ -979,7 +980,7 @@ async def test_that_verify_session_return_401_if_access_token_is_not_sent_and_mi # Create a session and get access token s = await create_new_session_without_request_response( - "public", "userId", {}, {} + "public", RecipeUserId("userId"), {}, {} ) access_token = s.get_access_token() headers = {"HTTP_AUTHORIZATION": "Bearer " + access_token} diff --git a/tests/Fastapi/test_fastapi.py b/tests/Fastapi/test_fastapi.py index 33f3b5013..6a76a3168 100644 --- a/tests/Fastapi/test_fastapi.py +++ b/tests/Fastapi/test_fastapi.py @@ -16,6 +16,7 @@ from fastapi import Depends, FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -91,7 +92,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") @@ -135,12 +136,12 @@ async def custom_logout(request: Request): # type: ignore @app.post("/create") async def _create(request: Request): # type: ignore - await create_new_session(request, "public", "userId", {}, {}) + await create_new_session(request, "public", RecipeUserId("userId"), {}, {}) return "" @app.post("/create-throw") async def _create_throw(request: Request): # type: ignore - await create_new_session(request, "public", "userId", {}, {}) + await create_new_session(request, "public", RecipeUserId("userId"), {}, {}) raise UnauthorisedError("unauthorised") return TestClient(app) diff --git a/tests/Flask/test_flask.py b/tests/Flask/test_flask.py index d87292fe2..ec45aadc3 100644 --- a/tests/Flask/test_flask.py +++ b/tests/Flask/test_flask.py @@ -32,6 +32,7 @@ refresh_session, revoke_session, ) +from supertokens_python.types import RecipeUserId from tests.Flask.utils import extract_all_cookies from tests.utils import ( TEST_ACCESS_TOKEN_MAX_AGE_CONFIG_KEY, @@ -193,7 +194,7 @@ def t(): # type: ignore @app.route("/login") # type: ignore def login(): # type: ignore user_id = "userId" - create_new_session(request, "public", user_id, {}, {}) + create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return jsonify({"userId": user_id, "session": "ssss"}) @@ -753,7 +754,7 @@ def test_api(): # type: ignore @app.route("/login") # type: ignore def login(): # type: ignore user_id = "userId" - s = create_new_session(request, "public", user_id, {}, {}) + s = create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return jsonify({"user": s.get_user_id()}) @app.route("/ping") # type: ignore @@ -832,7 +833,9 @@ def test_that_verify_session_return_401_if_access_token_is_not_sent_and_middlewa assert res.status_code == 401 assert res.json == {"message": "unauthorised"} - s = create_new_session_without_request_response("public", "userId", {}, {}) + s = create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) res = client.get( "/verify", headers={"Authorization": "Bearer " + s.get_access_token()} ) @@ -882,7 +885,7 @@ def _(_): @app.route("/create-session") # type: ignore def create_session_api(): # type: ignore - create_new_session(request, "public", "userId", {}, {}) + create_new_session(request, "public", RecipeUserId("userId"), {}, {}) return jsonify({}) return app diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index e5275d227..39a149043 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -483,7 +483,7 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - session: Optional[SessionContainer], + session: SessionContainer, api_options: SAPIOptions, user_context: Dict[str, Any], ): diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index fcaab51ba..81932186d 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -538,7 +538,7 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - session: Optional[SessionContainer], + session: SessionContainer, api_options: SAPIOptions, user_context: Dict[str, Any], ): diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 6e1618e30..9b7cfa3ee 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -487,7 +487,7 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - session: Optional[SessionContainer], + session: SessionContainer, api_options: SAPIOptions, user_context: Dict[str, Any], ): diff --git a/tests/emailpassword/test_emailexists.py b/tests/emailpassword/test_emailexists.py index a8425b027..748dde85d 100644 --- a/tests/emailpassword/test_emailexists.py +++ b/tests/emailpassword/test_emailexists.py @@ -16,6 +16,7 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -50,7 +51,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 34c316032..0bf066d68 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -85,7 +85,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") diff --git a/tests/emailpassword/test_passwordreset.py b/tests/emailpassword/test_passwordreset.py index 3b76544c0..ff00f46d8 100644 --- a/tests/emailpassword/test_passwordreset.py +++ b/tests/emailpassword/test_passwordreset.py @@ -18,6 +18,7 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, raises from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -56,7 +57,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") diff --git a/tests/emailpassword/test_signin.py b/tests/emailpassword/test_signin.py index 1757587f4..930ed45d4 100644 --- a/tests/emailpassword/test_signin.py +++ b/tests/emailpassword/test_signin.py @@ -16,6 +16,7 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -60,7 +61,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index b2d725449..e2472e21c 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -48,6 +48,7 @@ merge_into_access_token_payload, ) from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.syncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -279,6 +280,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -290,6 +292,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/django3x/polls/views.py b/tests/frontendIntegration/django3x/polls/views.py index 6cfc59c76..0ec1fa4d2 100644 --- a/tests/frontendIntegration/django3x/polls/views.py +++ b/tests/frontendIntegration/django3x/polls/views.py @@ -48,6 +48,7 @@ from supertokens_python.recipe.session.asyncio import merge_into_access_token_payload from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -282,6 +283,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -293,6 +295,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/drf_async/mysite/settings.py b/tests/frontendIntegration/drf_async/mysite/settings.py index f30de3e1f..0ba39a229 100644 --- a/tests/frontendIntegration/drf_async/mysite/settings.py +++ b/tests/frontendIntegration/drf_async/mysite/settings.py @@ -16,6 +16,8 @@ # Build paths inside the project like this: os.path.join(BASE_DIR, ...) from corsheaders.defaults import default_headers +from supertokens_python.types import RecipeUserId + BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Quick-start development settings - unsuitable for production @@ -173,6 +175,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -184,6 +187,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index ddab05beb..757ae922f 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -54,6 +54,7 @@ from supertokens_python.recipe.session.asyncio import merge_into_access_token_payload from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -306,6 +307,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -317,6 +319,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/drf_sync/mysite/settings.py b/tests/frontendIntegration/drf_sync/mysite/settings.py index 2f2c39b2a..bfcf33d19 100644 --- a/tests/frontendIntegration/drf_sync/mysite/settings.py +++ b/tests/frontendIntegration/drf_sync/mysite/settings.py @@ -16,6 +16,8 @@ # Build paths inside the project like this: os.path.join(BASE_DIR, ...) from corsheaders.defaults import default_headers +from supertokens_python.types import RecipeUserId + BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Quick-start development settings - unsuitable for production @@ -173,6 +175,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -184,6 +187,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/drf_sync/polls/views.py b/tests/frontendIntegration/drf_sync/polls/views.py index 148c7ce98..bbcbbb5d3 100644 --- a/tests/frontendIntegration/drf_sync/polls/views.py +++ b/tests/frontendIntegration/drf_sync/polls/views.py @@ -54,6 +54,7 @@ from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.syncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -306,6 +307,7 @@ def functions_override_session(param: RecipeInterface): def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -317,6 +319,7 @@ def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index 013e89b32..be1d31f96 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -52,6 +52,7 @@ RecipeInterface, ) from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.querier import Querier @@ -136,6 +137,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -147,6 +149,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/flask-server/app.py b/tests/frontendIntegration/flask-server/app.py index f99577eb6..e829b61fc 100644 --- a/tests/frontendIntegration/flask-server/app.py +++ b/tests/frontendIntegration/flask-server/app.py @@ -44,6 +44,7 @@ ) from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.syncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -162,6 +163,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -173,6 +175,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -295,7 +298,7 @@ def login_options(): @app.route("/login", methods=["POST"]) # type: ignore def login(): user_id: str = request.get_json()["userId"] # type: ignore - _session = create_new_session(request, "public", user_id) + _session = create_new_session(request, "public", RecipeUserId(user_id)) return _session.get_user_id() diff --git a/tests/jwt/test_get_JWKS.py b/tests/jwt/test_get_JWKS.py index 0cde0912c..20c945ab1 100644 --- a/tests/jwt/test_get_JWKS.py +++ b/tests/jwt/test_get_JWKS.py @@ -25,6 +25,7 @@ from supertokens_python.recipe import jwt from supertokens_python.recipe.jwt.interfaces import APIInterface, RecipeInterface from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.types import RecipeUserId from tests.utils import clean_st, reset, setup_st, start_st @@ -50,7 +51,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} return TestClient(app) diff --git a/tests/sessions/claims/test_assert_claims.py b/tests/sessions/claims/test_assert_claims.py index bb7cde83f..cb6942b7d 100644 --- a/tests/sessions/claims/test_assert_claims.py +++ b/tests/sessions/claims/test_assert_claims.py @@ -16,6 +16,7 @@ ) from supertokens_python.recipe.session.session_class import Session from supertokens_python import init +from supertokens_python.types import RecipeUserId from tests.utils import ( get_st_init_args, setup_function, @@ -58,6 +59,7 @@ async def test_should_not_throw_for_empty_array(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, False, # access_token_updated @@ -96,6 +98,7 @@ async def test_should_call_validate_with_the_same_payload_object(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), payload, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -142,5 +145,7 @@ async def test_assert_claims_should_work(): start_st() validator = TrueClaim.validators.is_true(1) - s = await create_new_session_without_request_response("public", "userid", {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userid"), {} + ) await s.assert_claims([validator]) diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index a6e5a2195..1207f8335 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -4,6 +4,7 @@ from supertokens_python.framework import BaseRequest from supertokens_python.recipe import session from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.types import RecipeUserId from tests.utils import ( setup_function, teardown_function, @@ -28,7 +29,7 @@ async def test_create_access_token_payload_with_session_claims(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() assert len(payload) == 10 @@ -41,7 +42,7 @@ async def test_should_create_access_token_payload_with_session_claims_with_an_no start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() assert len(payload) == 9 @@ -66,7 +67,7 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() assert len(payload) == 11 diff --git a/tests/sessions/claims/test_fetch_and_set_claim.py b/tests/sessions/claims/test_fetch_and_set_claim.py index 8e8d3ccad..66c0c2d50 100644 --- a/tests/sessions/claims/test_fetch_and_set_claim.py +++ b/tests/sessions/claims/test_fetch_and_set_claim.py @@ -3,6 +3,7 @@ from pytest import mark from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import NoneClaim, TrueClaim from tests.utils import AsyncMock, MagicMock @@ -28,6 +29,7 @@ async def test_should_not_change_if_claim_fetch_value_returns_none(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, False, # access_token_updated @@ -60,6 +62,7 @@ async def test_should_update_if_claim_fetch_value_returns_value(timestamp: int): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated diff --git a/tests/sessions/claims/test_get_claim_value.py b/tests/sessions/claims/test_get_claim_value.py index cea4c1afa..93d6db915 100644 --- a/tests/sessions/claims/test_get_claim_value.py +++ b/tests/sessions/claims/test_get_claim_value.py @@ -14,6 +14,7 @@ GetClaimValueOkResult, SessionDoesNotExistError, ) +from supertokens_python.types import RecipeUserId from tests.utils import setup_function, teardown_function, start_st, st_init_common_args from .utils import TrueClaim, get_st_init_args @@ -28,7 +29,7 @@ async def test_should_get_the_right_value(): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) res = await s.get_claim_value(TrueClaim) assert res is True @@ -39,7 +40,9 @@ async def test_should_get_the_right_value_using_session_handle(): start_st() dummy_req: BaseRequest = MagicMock() - s: SessionContainer = await create_new_session(dummy_req, "public", "someId") + s: SessionContainer = await create_new_session( + dummy_req, "public", RecipeUserId("someId") + ) res = await get_claim_value(s.get_handle(), TrueClaim) assert isinstance(res, GetClaimValueOkResult) diff --git a/tests/sessions/claims/test_remove_claim.py b/tests/sessions/claims/test_remove_claim.py index b448af1ae..a954cacbd 100644 --- a/tests/sessions/claims/test_remove_claim.py +++ b/tests/sessions/claims/test_remove_claim.py @@ -11,6 +11,7 @@ remove_claim, ) from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import TrueClaim, get_st_init_args from tests.utils import AsyncMock, setup_function, start_st, teardown_function @@ -37,6 +38,7 @@ async def test_should_attempt_to_set_claim_to_none(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -56,7 +58,9 @@ async def test_should_clear_previously_set_claim(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s: SessionContainer = await create_new_session(dummy_req, "public", "someId") + s: SessionContainer = await create_new_session( + dummy_req, "public", RecipeUserId("someId") + ) payload = s.get_access_token_payload() @@ -68,7 +72,9 @@ async def test_should_clear_previously_set_claim_using_handle(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s: SessionContainer = await create_new_session(dummy_req, "public", "someId") + s: SessionContainer = await create_new_session( + dummy_req, "public", RecipeUserId("someId") + ) payload = s.get_access_token_payload() assert payload["st-true"] == {"v": True, "t": timestamp} diff --git a/tests/sessions/claims/test_set_claim_value.py b/tests/sessions/claims/test_set_claim_value.py index 170cd66a4..91c518f8e 100644 --- a/tests/sessions/claims/test_set_claim_value.py +++ b/tests/sessions/claims/test_set_claim_value.py @@ -10,6 +10,7 @@ set_claim_value, ) from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import TrueClaim, get_st_init_args from tests.utils import AsyncMock, setup_function, start_st, teardown_function @@ -38,6 +39,7 @@ async def test_should_merge_the_right_value(timestamp: int): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -57,7 +59,7 @@ async def test_should_overwrite_claim_value(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() assert len(payload) == 10 @@ -76,7 +78,7 @@ async def test_should_overwrite_claim_value_using_session_handle(timestamp: int) start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() assert len(payload) == 10 diff --git a/tests/sessions/claims/test_validate_claims_for_session_handle.py b/tests/sessions/claims/test_validate_claims_for_session_handle.py index 72153ecee..71c62c69d 100644 --- a/tests/sessions/claims/test_validate_claims_for_session_handle.py +++ b/tests/sessions/claims/test_validate_claims_for_session_handle.py @@ -13,6 +13,7 @@ ClaimsValidationResult, SessionDoesNotExistError, ) +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import ( get_st_init_args, NoneClaim, @@ -31,7 +32,7 @@ async def test_should_return_the_right_validation_errors(): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) failing_validator = NoneClaim.validators.has_value(True) res = await validate_claims_for_session_handle( diff --git a/tests/sessions/claims/test_verify_session.py b/tests/sessions/claims/test_verify_session.py index ae373e064..12f3b5461 100644 --- a/tests/sessions/claims/test_verify_session.py +++ b/tests/sessions/claims/test_verify_session.py @@ -28,6 +28,7 @@ ClaimsValidationResult, ) from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import TrueClaim, NoneClaim from tests.utils import ( setup_function, @@ -50,8 +51,9 @@ def st_init_generator_with_overriden_global_validators( ): def session_function_override(oi: RecipeInterface) -> RecipeInterface: async def new_get_global_claim_validators( - _user_id: str, _tenant_id: str, + _user_id: str, + _recipe_user_id: RecipeUserId, _claim_validators_added_by_other_recipes: List[SessionClaimValidator], _user_context: Dict[str, Any], ): @@ -77,8 +79,9 @@ async def new_get_global_claim_validators( def st_init_generator_with_claim_validator(claim_validator: SessionClaimValidator): def session_function_override(oi: RecipeInterface) -> RecipeInterface: async def new_get_global_claim_validators( - _user_id: str, _tenant_id: str, + _user_id: str, + _recipe_user_id: RecipeUserId, claim_validators_added_by_other_recipes: List[SessionClaimValidator], _user_context: Dict[str, Any], ): @@ -138,13 +141,13 @@ async def fastapi_client(): @app.post("/login") async def _login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/create-with-claim") async def _create_with_claim(request: Request): # type: ignore user_id = "userId" - _ = await create_new_session(request, "public", user_id, {}, {}) + _ = await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) key: str = (await request.json())["key"] # PrimitiveClaim(key, fetch_value="Value").add_to_session(session, "value") return {"userId": key} @@ -377,6 +380,7 @@ async def test_should_reject_if_assert_claims_returns_an_error( None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -426,6 +430,7 @@ async def test_should_allow_if_assert_claims_returns_no_error( None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 529cd05a8..b55fd7f6e 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -24,6 +24,7 @@ def session_function_override(oi: RecipeInterface) -> RecipeInterface: async def new_create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[None, Dict[str, Any]], session_data_in_database: Union[None, Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -43,6 +44,7 @@ async def new_create_new_session( return await oi_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/sessions/test_access_token_version.py b/tests/sessions/test_access_token_version.py index 55e0448be..4000bebd5 100644 --- a/tests/sessions/test_access_token_version.py +++ b/tests/sessions/test_access_token_version.py @@ -17,6 +17,7 @@ validate_access_token_structure, ) from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.types import RecipeUserId from tests.utils import get_st_init_args, setup_function, start_st, teardown_function _ = setup_function # type:ignore @@ -30,7 +31,9 @@ async def test_access_token_v4(): start_st() access_token = ( - await create_new_session_without_request_response("public", "user-id") + await create_new_session_without_request_response( + "public", RecipeUserId("user-id") + ) ).get_access_token() s = await get_session_without_request_response(access_token) assert s is not None @@ -79,7 +82,9 @@ async def _create(request: Request): # type: ignore except Exception: pass - session = await create_new_session(request, "public", "userId", body, {}) + session = await create_new_session( + request, "public", RecipeUserId("userId"), body, {} + ) return {"message": True, "sessionHandle": session.get_handle()} @fast.get("/merge-into-payload") @@ -212,7 +217,7 @@ async def test_ignore_protected_props_in_create_session(): s = await create_new_session_without_request_response( "public", - "user1", + RecipeUserId("user1"), {"foo": "bar"}, ) payload = parse_jwt_without_signature_verification(s.access_token).payload @@ -220,7 +225,7 @@ async def test_ignore_protected_props_in_create_session(): assert payload["sub"] == "user1" s2 = await create_new_session_without_request_response( - "public", "user2", s.get_access_token_payload() + "public", RecipeUserId("user2"), s.get_access_token_payload() ) payload = parse_jwt_without_signature_verification(s2.access_token).payload assert payload["foo"] == "bar" diff --git a/tests/sessions/test_auth_mode.py b/tests/sessions/test_auth_mode.py index 0a01e0a0b..69b0b1b86 100644 --- a/tests/sessions/test_auth_mode.py +++ b/tests/sessions/test_auth_mode.py @@ -2,6 +2,7 @@ from typing_extensions import Literal from fastapi import Depends, FastAPI, Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import init @@ -33,7 +34,9 @@ async def app(): @fast.post("/create") async def _create(request: Request): # type: ignore body = await request.json() - session = await create_new_session(request, "public", "userId", body, {}) + session = await create_new_session( + request, "public", RecipeUserId("userId"), body, {} + ) return {"message": True, "sessionHandle": session.get_handle()} @fast.get("/update-payload") diff --git a/tests/sessions/test_jwks.py b/tests/sessions/test_jwks.py index 7affdbf31..7b800210d 100644 --- a/tests/sessions/test_jwks.py +++ b/tests/sessions/test_jwks.py @@ -15,6 +15,7 @@ get_session_without_request_response, ) from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.types import RecipeUserId from supertokens_python.utils import get_timestamp_ms from tests.utils import ( get_st_init_args, @@ -85,7 +86,9 @@ async def test_that_jwks_is_fetched_as_expected(caplog: LogCaptureFixture): assert next(well_known_count) == 0 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) time.sleep(jwk_max_age_sec) tokens = s.get_all_session_tokens_dangerously() @@ -173,7 +176,9 @@ async def test_that_jwks_are_refresh_if_kid_is_unknown(caplog: LogCaptureFixture assert next(well_known_count) == 0 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert next(well_known_count) == 0 @@ -188,7 +193,9 @@ async def test_that_jwks_are_refresh_if_kid_is_unknown(caplog: LogCaptureFixture assert next(well_known_count) == 1 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert next(well_known_count) == 1 @@ -258,7 +265,9 @@ async def test_jwks_cache_logic(caplog: LogCaptureFixture): assert next(jwks_refresh_count) == 0 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert get_cached_keys() is None assert next(jwks_refresh_count) == 0 @@ -383,7 +392,9 @@ async def test_that_jwks_returns_from_cache_correctly(caplog: LogCaptureFixture) init(**get_st_init_args(recipe_list=[session.init(jwks_refresh_interval_sec=2)])) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert get_cached_keys() is None assert next(jwk_refresh_count) == 0 assert next(returned_from_cache_count) == 0 @@ -478,7 +489,9 @@ async def test_session_verification_of_jwt_based_on_session_payload( init(**get_st_init_args(recipe_list=[session.init()])) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) payload = s.get_access_token_payload() del payload["iat"] @@ -500,7 +513,9 @@ async def test_session_verification_of_jwt_based_on_session_payload_with_check_d init(**get_st_init_args(recipe_list=[session.init()])) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) payload = s.get_access_token_payload() del payload["iat"] @@ -525,7 +540,9 @@ async def test_session_verification_of_jwt_with_dynamic_signing_key(): ) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) payload = s.get_access_token_payload() del payload["iat"] @@ -645,7 +662,7 @@ async def client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "test" - s = await create_new_session(request, "public", user_id, {}, {}) + s = await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"jwt": s.get_access_token()} @app.get("/sessioninfo") diff --git a/tests/sessions/test_session_error_handlers.py b/tests/sessions/test_session_error_handlers.py index 2dc052a97..0cb569297 100644 --- a/tests/sessions/test_session_error_handlers.py +++ b/tests/sessions/test_session_error_handlers.py @@ -17,6 +17,7 @@ from fastapi import FastAPI from fastapi.requests import Request from pytest import fixture, mark +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import init @@ -66,7 +67,7 @@ async def test_try_refresh(_request: Request): # type: ignore @app.post("/test/token-theft") async def test_token_theft(_request: Request): # type: ignore - raise TokenTheftError("", "") + raise TokenTheftError("", RecipeUserId(""), "") @app.post("/test/claim-validation") async def test_claim_validation(_request: Request): # type: ignore @@ -88,7 +89,11 @@ def unauthorised_f(_req: BaseRequest, _message: str, res: BaseResponse): return res def token_theft_f( - _req: BaseRequest, _session_handle: str, _user_id: str, res: BaseResponse + _req: BaseRequest, + _session_handle: str, + _user_id: str, + _rid: RecipeUserId, + res: BaseResponse, ): res.set_status_code(403) res.set_json_content({"message": "token theft detected from errorHandler"}) diff --git a/tests/sessions/test_use_dynamic_signing_key_switching.py b/tests/sessions/test_use_dynamic_signing_key_switching.py index 970a28df7..9dc3fc84b 100644 --- a/tests/sessions/test_use_dynamic_signing_key_switching.py +++ b/tests/sessions/test_use_dynamic_signing_key_switching.py @@ -7,6 +7,7 @@ refresh_session_without_request_response, ) from supertokens_python.recipe.session.session_class import SessionContainer +from supertokens_python.types import RecipeUserId from tests.utils import ( get_st_init_args, setup_function, @@ -31,7 +32,7 @@ async def test_dynamic_key_switching(): # Create a new session without an actual HTTP request-response flow create_res: SessionContainer = await create_new_session_without_request_response( - "public", "test-user-id", {"tokenProp": True}, {"dbProp": True} + "public", RecipeUserId("test-user-id"), {"tokenProp": True}, {"dbProp": True} ) # Extract session tokens for further testing @@ -67,7 +68,7 @@ async def test_refresh_session(): # Create a new session without an actual HTTP request-response flow create_res: SessionContainer = await create_new_session_without_request_response( - "public", "test-user-id", {"tokenProp": True}, {"dbProp": True} + "public", RecipeUserId("test-user-id"), {"tokenProp": True}, {"dbProp": True} ) # Extract session tokens for further testing diff --git a/tests/test-server/session.py b/tests/test-server/session.py index df2f1cc7f..1d0574122 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -85,6 +85,7 @@ def assert_claims(): # type: ignore anti_csrf_token=None, # We don't have anti-csrf token in the input session_handle=data["session"]["sessionHandle"], user_id=data["session"]["userId"], + recipe_user_id=data["session"]["recipeUserId"], user_data_in_access_token=data["session"]["userDataInAccessToken"], req_res_info=None, # We don't have this information in the input access_token_updated=data["session"]["accessTokenUpdated"], @@ -143,6 +144,7 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore anti_csrf_token=None, # We don't have anti-csrf token in the input session_handle=data["session"]["sessionHandle"], user_id=data["session"]["userId"], + recipe_user_id=data["session"]["recipeUserId"], user_data_in_access_token=data["session"]["userDataInAccessToken"], req_res_info=None, # We don't have this information in the input access_token_updated=data["session"]["accessTokenUpdated"], diff --git a/tests/test_config.py b/tests/test_config.py index 1a8b30963..d8a64761b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -21,6 +21,7 @@ from supertokens_python.recipe.session.asyncio import create_new_session from typing import Optional, Dict, Any from supertokens_python.framework import BaseRequest +from supertokens_python.types import RecipeUserId from tests.utils import clean_st, reset, setup_st, start_st @@ -698,7 +699,7 @@ async def test_samesite_valid_config(): ["http://localhost:3000", "https://supertokensapi.io"], ["http://127.0.0.1:3000", "https://supertokensapi.io"], ] - for (website_domain, api_domain) in domain_combinations: + for website_domain, api_domain in domain_combinations: reset() clean_st() setup_st() @@ -724,7 +725,7 @@ async def test_samesite_invalid_config(): ["http://supertokens.io", "http://127.0.0.1:8000"], ["http://supertokens.io", "http://supertokensapi.io"], ] - for (website_domain, api_domain) in domain_combinations: + for website_domain, api_domain in domain_combinations: reset() clean_st() setup_st() @@ -744,7 +745,9 @@ async def test_samesite_invalid_config(): ) ], ) - await create_new_session("public", MagicMock(), "userId", {}, {}) + await create_new_session( + "public", MagicMock(), RecipeUserId("userId"), {}, {} + ) except Exception as e: assert ( str(e) diff --git a/tests/test_session.py b/tests/test_session.py index f42446095..1f88a1417 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -21,6 +21,7 @@ from fastapi.requests import Request from fastapi.responses import JSONResponse from pytest import fixture, mark +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from requests.cookies import cookiejar_from_dict # type: ignore @@ -105,7 +106,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", "", False, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(""), False, {}, {}, None ) assert response.session is not None @@ -210,7 +211,7 @@ async def test_creating_many_sessions_for_one_user_and_looping(): new_session = await create_new_session( s.recipe_implementation, "public", - "someUser", + RecipeUserId("someUser"), False, {"someKey": "someValue"}, {}, @@ -218,7 +219,7 @@ async def test_creating_many_sessions_for_one_user_and_looping(): ) access_tokens.append(new_session.accessToken.token) - session_handles = await get_all_session_handles_for_user("someUser", "public") + session_handles = await get_all_session_handles_for_user("someUser", True, "public") assert len(session_handles) == 7 @@ -280,7 +281,9 @@ async def home(_request: Request): # type: ignore @app.post("/create") async def create_api(request: Request): # type: ignore - await async_create_new_session(request, "public", "test-user", {}, {}) + await async_create_new_session( + request, "public", RecipeUserId("test-user"), {}, {} + ) return "" @app.post("/sessioninfo-optional") @@ -313,7 +316,7 @@ async def test_signout_api_works_even_if_session_is_deleted_after_creation( user_id = "user_id" response = await create_new_session( - s.recipe_implementation, "public", user_id, False, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), False, {}, {}, None ) session_handle = response.session.handle @@ -400,7 +403,9 @@ async def get_session_information( mock_response = MagicMock() - my_session = await async_create_new_session(mock_response, "public", "test_id") + my_session = await async_create_new_session( + mock_response, "public", RecipeUserId("test_id"), {} + ) data = await my_session.get_session_data_from_database() assert data == {"foo": "bar"} @@ -733,7 +738,9 @@ async def test_that_verify_session_doesnt_always_call_core(): # response = await create_new_session(s.recipe_implementation, "", False, {}, {}) - session1 = await create_new_session_without_request_response("public", "user-id") + session1 = await create_new_session_without_request_response( + "public", RecipeUserId("user-id") + ) assert session1 is not None assert session1.access_token != "" diff --git a/tests/test_user_context.py b/tests/test_user_context.py index 5b51fd08d..55bd67ad3 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -15,6 +15,7 @@ from pathlib import Path from fastapi import FastAPI +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark @@ -126,6 +127,7 @@ def functions_override_session(param: SRecipeInterface): async def create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -140,6 +142,7 @@ async def create_new_session( user_context["preCreateNewSession"] = True response = await og_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -240,6 +243,7 @@ def functions_override_session(param: SRecipeInterface): async def create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -253,6 +257,7 @@ async def create_new_session( response = await og_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -363,6 +368,7 @@ def functions_override_session(param: SRecipeInterface): async def create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -378,6 +384,7 @@ async def create_new_session( response = await og_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/thirdparty/test_emaildelivery.py b/tests/thirdparty/test_emaildelivery.py index 0ca617ac1..d42fd65c5 100644 --- a/tests/thirdparty/test_emaildelivery.py +++ b/tests/thirdparty/test_emaildelivery.py @@ -19,6 +19,7 @@ import respx from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark @@ -148,7 +149,7 @@ async def test_email_verify_default_backward_compatibility( assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) user_id = resp.user.user_id response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -222,7 +223,7 @@ async def test_email_verify_default_backward_compatibility_supress_error( assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) user_id = resp.user.user_id response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -312,7 +313,7 @@ async def send_email( assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) user_id = resp.user.user_id response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( @@ -391,7 +392,7 @@ async def send_email( user_id = resp.user.user_id assert isinstance(user_id, str) response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -531,7 +532,7 @@ async def send_email_override( user_id = resp.user.user_id assert isinstance(user_id, str) response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( diff --git a/tests/userroles/test_claims.py b/tests/userroles/test_claims.py index ed0160bea..455f915b7 100644 --- a/tests/userroles/test_claims.py +++ b/tests/userroles/test_claims.py @@ -19,6 +19,7 @@ from supertokens_python import init from supertokens_python.recipe import userroles, session from supertokens_python.recipe.session.exceptions import ClaimValidationError +from supertokens_python.types import RecipeUserId from tests.utils import ( start_st, setup_function, @@ -56,7 +57,7 @@ async def test_add_claims_to_session_without_config(): user_id = "userId" req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) assert s.sync_get_claim_value(UserRoleClaim) == [] assert (await s.get_claim_value(PermissionClaim)) == [] @@ -78,7 +79,7 @@ async def test_claims_not_added_to_session_if_disabled(): user_id = "userId" req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) assert (await s.get_claim_value(UserRoleClaim)) is None assert s.sync_get_claim_value(PermissionClaim) is None @@ -101,7 +102,7 @@ async def test_add_claims_to_session_with_values(): await create_new_role_or_add_permissions(role, ["a", "b"]) await add_role_to_user("public", user_id, role) - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) assert s.sync_get_claim_value(UserRoleClaim) == [role] value: List[str] = await s.get_claim_value(PermissionClaim) # type: ignore assert sorted(value) == sorted(["a", "b"]) @@ -126,7 +127,7 @@ async def test_should_validate_roles(): await create_new_role_or_add_permissions(role, ["a", "b"]) await add_role_to_user("public", user_id, role) - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await s.assert_claims([UserRoleClaim.validators.includes(role)]) with pytest.raises(Exception) as e: @@ -159,7 +160,7 @@ async def test_should_validate_roles_after_refetch(): role = "role" req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await create_new_role_or_add_permissions(role, ["a", "b"]) await add_role_to_user("public", user_id, role) @@ -187,7 +188,7 @@ async def test_should_validate_permissions(): await create_new_role_or_add_permissions(role, permissions) await add_role_to_user("public", user_id, role) - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await s.assert_claims([PermissionClaim.validators.includes("a")]) with pytest.raises(Exception) as e: @@ -223,7 +224,7 @@ async def test_should_validate_permissions_after_refetch(): permissions = ["a", "b"] req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await create_new_role_or_add_permissions(role, permissions) await add_role_to_user("public", user_id, role) From 1cd37b7448e9a05df46c3efe9f8085efb991d659 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 13 Aug 2024 13:10:44 +0530 Subject: [PATCH 015/126] adds interfaces for mfa recipe --- .../recipe/multifactorauth/interfaces.py | 147 ++++++++++++++++++ .../recipe/multifactorauth/types.py | 60 +++++++ 2 files changed, 207 insertions(+) create mode 100644 supertokens_python/recipe/multifactorauth/interfaces.py create mode 100644 supertokens_python/recipe/multifactorauth/types.py diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py new file mode 100644 index 000000000..3d2708eae --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -0,0 +1,147 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Dict, Any, Union, List, Callable, Awaitable + +from supertokens_python.types import AccountLinkingUser + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from ...supertokens import AppInfo + +from ...types import GeneralErrorResponse + +if TYPE_CHECKING: + from supertokens_python.framework import BaseRequest, BaseResponse + from supertokens_python.recipe.session import SessionContainer + from .types import MFARequirementList, AccountLinkingConfig + + +class RecipeInterface(ABC): + @abstractmethod + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + self, + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: MFARequirementList, + factors_set_up_for_user: Awaitable[Callable[[], List[str]]], + user_context: Dict[str, Any], + ) -> None: + pass + + @abstractmethod + async def get_mfa_requirements_for_auth( + self, + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Awaitable[Callable[[], AccountLinkingUser]], + factors_set_up_for_user: Awaitable[Callable[[], List[str]]], + required_secondary_factors_for_user: Awaitable[Callable[[], List[str]]], + required_secondary_factors_for_tenant: Awaitable[Callable[[], List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + pass + + @abstractmethod + async def mark_factor_as_complete_in_session( + self, + session: SessionContainer, + factor_id: str, + user_context: Dict[str, Any], + ) -> None: + pass + + @abstractmethod + async def get_factors_setup_for_user( + self, user: AccountLinkingUser, user_context: Dict[str, Any] + ) -> List[str]: + pass + + @abstractmethod + async def get_required_secondary_factors_for_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> List[str]: + pass + + @abstractmethod + async def add_to_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ) -> None: + pass + + @abstractmethod + async def remove_from_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ) -> None: + pass + + +class APIOptions: + def __init__( + self, + request: BaseRequest, + response: BaseResponse, + recipe_id: str, + config: AccountLinkingConfig, + recipe_implementation: RecipeInterface, + app_info: AppInfo, + ): + self.request: BaseRequest = request + self.response: BaseResponse = response + self.recipe_id: str = recipe_id + self.config: AccountLinkingConfig = config + self.recipe_implementation: RecipeInterface = recipe_implementation + self.app_info = app_info + + +class APIInterface: + def __init__(self): + self.disable_resync_session_and_fetch_mfa_info_put = False + + @abstractmethod + async def resync_session_and_fetch_mfa_info_put( + self, + api_options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + pass + + +class NextFactors: + def __init__( + self, next: List[str], already_setup: List[str], allowed_to_setup: List[str] + ): + self.next = next + self.already_setup = already_setup + self.allowed_to_setup = allowed_to_setup + + +class ResyncSessionAndFetchMFAInfoPUTOkResult: + def __init__( + self, + factors: NextFactors, + emails: Dict[str, Union[List[str], None]], + phone_numbers: Dict[str, Union[List[str], None]], + ): + self.factors = factors + self.emails = emails + self.phone_numbers = phone_numbers + + status: str = "OK" diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py new file mode 100644 index 000000000..f325c9575 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union, List, Optional, Callable +from .interfaces import RecipeInterface, APIInterface + + +class MFARequirementList(List[Union[Dict[str, List[str]], str]]): + def __init__(self, *args: Union[str, Dict[str, List[str]]]): + super().__init__() + for arg in args: + if isinstance(arg, str): + self.append(arg) + else: + if "oneOf" in arg: + self.append({"oneOf": arg["oneOf"]}) + elif "allOfInAnyOrder" in arg: + self.append({"allOfInAnyOrder": arg["allOfInAnyOrder"]}) + else: + raise ValueError("Invalid dictionary format") + + +class MFAClaimValue: + c: Dict[str, Any] + v: bool + + def __init__(self, c: Dict[str, Any], v: bool): + self.c = c + self.v = v + + +class OverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class AccountLinkingConfig: + def __init__( + self, + first_factors: Optional[List[str]], + override: OverrideConfig, + ): + self.first_factors = first_factors + self.override = override From 8b7e0c4df4c42f36835d789c288807e048b63ce2 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 13 Aug 2024 13:26:52 +0530 Subject: [PATCH 016/126] more types --- .../recipe/multifactorauth/interfaces.py | 6 ++-- .../recipe/multifactorauth/types.py | 12 ++++++- .../recipe/multifactorauth/utils.py | 36 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 supertokens_python/recipe/multifactorauth/utils.py diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 3d2708eae..22de2d2c0 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer - from .types import MFARequirementList, AccountLinkingConfig + from .types import MFARequirementList, MultiFactorAuthConfig class RecipeInterface(ABC): @@ -98,14 +98,14 @@ def __init__( request: BaseRequest, response: BaseResponse, recipe_id: str, - config: AccountLinkingConfig, + config: MultiFactorAuthConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, ): self.request: BaseRequest = request self.response: BaseResponse = response self.recipe_id: str = recipe_id - self.config: AccountLinkingConfig = config + self.config = config self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info = app_info diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index f325c9575..c1f35d6b9 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -50,7 +50,17 @@ def __init__( self.apis = apis -class AccountLinkingConfig: +class InputOverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class MultiFactorAuthConfig: def __init__( self, first_factors: Optional[List[str]], diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py new file mode 100644 index 000000000..b3b181916 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Union + + +if TYPE_CHECKING: + from .types import OverrideConfig, InputOverrideConfig, MultiFactorAuthConfig + + +def validate_and_normalise_user_input( + first_factors: Optional[List[str]], + override: Union[InputOverrideConfig, None] = None, +) -> MultiFactorAuthConfig: + if first_factors is not None and len(first_factors) == 0: + raise ValueError("'first_factors' can be either None or a non-empty list") + + if override is None: + override = InputOverrideConfig() + + return MultiFactorAuthConfig( + first_factors=first_factors, + override=OverrideConfig(functions=override.functions, apis=override.apis), + ) From 023830a9a16e5d92d575224f2be81eff0d7626b5 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 13 Aug 2024 15:03:39 +0530 Subject: [PATCH 017/126] more changes --- .../recipe/emailverification/recipe.py | 1 + .../recipe/multifactorauth/interfaces.py | 10 +- .../multi_factor_auth_claim.py | 251 ++++++++++++++++ .../recipe/multifactorauth/recipe.py | 106 +++++++ .../recipe/multifactorauth/types.py | 31 +- .../recipe/multifactorauth/utils.py | 275 +++++++++++++++++- .../recipe/multitenancy/interfaces.py | 4 + .../recipe/multitenancy/recipe.py | 2 + .../claim_base_classes/boolean_claim.py | 2 +- .../claim_base_classes/primitive_claim.py | 2 +- .../recipe/session/interfaces.py | 7 +- .../recipe/session/recipe_implementation.py | 1 + tests/sessions/claims/test_assert_claims.py | 4 +- .../claims/test_primitive_array_claim.py | 70 +++-- tests/sessions/claims/test_primitive_claim.py | 42 ++- 15 files changed, 753 insertions(+), 55 deletions(-) create mode 100644 supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py create mode 100644 supertokens_python/recipe/multifactorauth/recipe.py diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index d257026c4..77620b78c 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -477,6 +477,7 @@ async def fetch_value( _: str, recipe_user_id: RecipeUserId, __: str, + ___: Dict[str, Any], user_context: Dict[str, Any], ) -> bool: recipe = EmailVerificationRecipe.get_instance_or_throw() diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 22de2d2c0..47964e46b 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -39,7 +39,7 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( session: SessionContainer, factor_id: str, mfa_requirements_for_auth: MFARequirementList, - factors_set_up_for_user: Awaitable[Callable[[], List[str]]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], user_context: Dict[str, Any], ) -> None: pass @@ -50,10 +50,10 @@ async def get_mfa_requirements_for_auth( tenant_id: str, access_token_payload: Dict[str, Any], completed_factors: Dict[str, int], - user: Awaitable[Callable[[], AccountLinkingUser]], - factors_set_up_for_user: Awaitable[Callable[[], List[str]]], - required_secondary_factors_for_user: Awaitable[Callable[[], List[str]]], - required_secondary_factors_for_tenant: Awaitable[Callable[[], List[str]]], + user: Callable[[], Awaitable[AccountLinkingUser]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], user_context: Dict[str, Any], ) -> MFARequirementList: pass diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py new file mode 100644 index 000000000..625b6c3cc --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional, Set + +from supertokens_python.recipe.session.interfaces import ( + ClaimValidationResult, + JSONObject, + SessionClaim, + SessionClaimValidator, +) +from supertokens_python.types import RecipeUserId + +from .types import ( + FactorIdsAndType, + MFAClaimValue, + MFARequirementList, +) +from .utils import update_and_get_mfa_related_info_in_session + + +class HasCompletedRequirementListSCV(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: MultiFactorAuthClaimClass, + requirement_list: MFARequirementList, + ): + super().__init__(id_) + self.claim: MultiFactorAuthClaimClass = claim + self.requirement_list = requirement_list + + async def should_refetch( + self, payload: Dict[str, Any], user_context: Dict[str, Any] + ) -> bool: + return ( + True + if self.claim.key not in payload or not payload[self.claim.key] + else False + ) + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + if len(self.requirement_list) == 0: + return ClaimValidationResult(is_valid=True) # no requirements to satisfy + + if (self.claim.key not in payload) or (not payload[self.claim.key]): + raise Exception( + "This should never happen, claim value not present in payload" + ) + + claim_val: MFAClaimValue = payload[self.claim.key] + + completed_factors = claim_val.c + next_set_of_unsatisfied_factors = ( + self.claim.get_next_set_of_unsatisfied_factors( + completed_factors, self.requirement_list + ) + ) + + if len(next_set_of_unsatisfied_factors.factor_ids) == 0: + return ClaimValidationResult( + is_valid=True + ) # No item in the requirementList is left unsatisfied, hence is Valid + + factor_ids = next_set_of_unsatisfied_factors.factor_ids + + if next_set_of_unsatisfied_factors.type == "string": + return ClaimValidationResult( + is_valid=False, + reason={ + "message": f"Factor validation failed: {factor_ids[0]} not completed", + "factor_id": factor_ids[0], + }, + ) + + elif next_set_of_unsatisfied_factors.type == "oneOf": + return ClaimValidationResult( + is_valid=False, + reason={ + "reason": f"None of these factors are complete in the session: {', '.join(factor_ids)}", + "one_of": factor_ids, + }, + ) + else: + return ClaimValidationResult( + is_valid=False, + reason={ + "reason": f"Some of the factors are not complete in the session: {', '.join(factor_ids)}", + "all_of_in_any_order": factor_ids, + }, + ) + + +class HasCompletedMFARequirementsForAuthSCV(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: MultiFactorAuthClaimClass, + ): + super().__init__(id_) + self.claim = claim + + async def should_refetch( + self, payload: Dict[str, Any], user_context: Dict[str, Any] + ) -> bool: + assert self.claim is not None + return ( + True + if self.claim.key not in payload or not payload[self.claim.key] + else False + ) + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + assert self.claim is not None + if self.claim.key not in payload or not payload[self.claim.key]: + raise Exception( + "This should never happen, claim value not present in payload" + ) + claim_val: MFAClaimValue = payload[self.claim.key] + + return ClaimValidationResult( + is_valid=claim_val.v, + reason=( + { + "message": "MFA requirement for auth is not satisfied", + } + if not claim_val.v + else None + ), + ) + + +class MultiFactorAuthClaimValidators: + def __init__(self, claim: MultiFactorAuthClaimClass): + self.claim = claim + + def has_completed_requirement_list( + self, requirement_list: MFARequirementList, claim_key: Optional[str] = None + ) -> SessionClaimValidator: + return HasCompletedRequirementListSCV( + id_=claim_key or self.claim.key, + claim=self.claim, + requirement_list=requirement_list, + ) + + def has_completed_mfa_requirements_for_auth( + self, claim_key: Optional[str] = None + ) -> SessionClaimValidator: + + return HasCompletedMFARequirementsForAuthSCV( + id_=claim_key or self.claim.key, + claim=self.claim, + ) + + +class MultiFactorAuthClaimClass(SessionClaim[MFAClaimValue]): + def __init__(self, key: Optional[str] = None): + key = key or "st-mfa" + + async def fetch_value( + user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> MFAClaimValue: + mfa_info = await update_and_get_mfa_related_info_in_session( + input_session=None, + input_updated_factor_id=None, + input_session_recipe_user_id=recipe_user_id, + input_tenant_id=tenant_id, + input_access_token_payload=current_payload, + user_context=user_context, + ) + return MFAClaimValue( + c=mfa_info.completed_factors, + v=mfa_info.is_mfa_requirements_for_auth_satisfied, + ) + + super().__init__(key or "st-mfa", fetch_value=fetch_value) + self.validators = MultiFactorAuthClaimValidators(claim=self) + + def get_next_set_of_unsatisfied_factors( + self, completed_factors: Dict[str, int], requirement_list: MFARequirementList + ) -> FactorIdsAndType: + for req in requirement_list: + next_factors: Set[str] = set() + factor_type = "string" + + if isinstance(req, str): + if req not in completed_factors: + factor_type = "string" + next_factors.add(req) + else: + if "oneOf" in req: + satisfied = any( + factor_id in completed_factors for factor_id in req["oneOf"] + ) + if not satisfied: + factor_type = "oneOf" + next_factors.update(req["oneOf"]) + elif "allOfInAnyOrder" in req: + factor_type = "allOfInAnyOrder" + next_factors.update( + factor_id + for factor_id in req["allOfInAnyOrder"] + if factor_id not in completed_factors + ) + + if len(next_factors) > 0: + return FactorIdsAndType(factor_ids=list(next_factors), type=factor_type) + + return FactorIdsAndType(factor_ids=[], type="string") + + def add_to_payload_( + self, + payload: JSONObject, + value: MFAClaimValue, + user_context: Optional[Dict[str, Any]] = None, + ) -> JSONObject: + prev_value = payload.get(self.key, {}) + return { + **payload, + self.key: { + "c": {**prev_value.get("c", {}), **value.c}, + "v": value.v, + }, + } + + def remove_from_payload( + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> JSONObject: + del payload[self.key] + return payload + + def remove_from_payload_by_merge_( + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> JSONObject: + payload[self.key] = None + return payload + + def get_value_from_payload( + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> Optional[MFAClaimValue]: + return payload.get(self.key) + + +MultiFactorAuthClaim = MultiFactorAuthClaimClass() diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py new file mode 100644 index 000000000..f0ecadd99 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from os import environ +from typing import Any, Dict, Optional, List, Union + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.multifactorauth.interfaces import RecipeInterface +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.supertokens import AppInfo +from .types import OverrideConfig + + +class MultiFactorAuthRecipe(RecipeModule): + recipe_id = "multifactorauth" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + first_factors: Optional[List[str]] = None, + override: Union[OverrideConfig, None] = None, + ): + super().__init__(recipe_id, app_info) + self.recipe_implementation: RecipeInterface + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + def get_apis_handled(self) -> List[APIHandled]: + return [] + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ): + return None + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + first_factors: Optional[List[str]] = None, + override: Union[OverrideConfig, None] = None, + ): + def func(app_info: AppInfo): + if MultiFactorAuthRecipe.__instance is None: + MultiFactorAuthRecipe.__instance = MultiFactorAuthRecipe( + MultiFactorAuthRecipe.recipe_id, + app_info, + first_factors, + override, + ) + return MultiFactorAuthRecipe.__instance + raise_general_exception( + "MultiFactorAuthRecipe recipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def get_instance_or_throw_error() -> MultiFactorAuthRecipe: + if MultiFactorAuthRecipe.__instance is not None: + return MultiFactorAuthRecipe.__instance + raise_general_exception( + "MultiFactorAuth recipe initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + MultiFactorAuthRecipe.__instance = None diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index c1f35d6b9..11b1672f6 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -14,6 +14,7 @@ from typing import Dict, Any, Union, List, Optional, Callable from .interfaces import RecipeInterface, APIInterface +from typing_extensions import Literal class MFARequirementList(List[Union[Dict[str, List[str]], str]]): @@ -50,16 +51,6 @@ def __init__( self.apis = apis -class InputOverrideConfig: - def __init__( - self, - functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, - apis: Union[Callable[[APIInterface], APIInterface], None] = None, - ): - self.functions = functions - self.apis = apis - - class MultiFactorAuthConfig: def __init__( self, @@ -68,3 +59,23 @@ def __init__( ): self.first_factors = first_factors self.override = override + + +class FactorIds: + EMAILPASSWORD: Literal["emailpassword"] = "emailpassword" + OTP_EMAIL: Literal["otp-email"] = "otp-email" + OTP_PHONE: Literal["otp-phone"] = "otp-phone" + LINK_EMAIL: Literal["link-email"] = "link-email" + LINK_PHONE: Literal["link-phone"] = "link-phone" + THIRDPARTY: Literal["thirdparty"] = "thirdparty" + TOTP: Literal["totp"] = "totp" + + +class FactorIdsAndType: + def __init__( + self, + factor_ids: List[str], + type: Union[Literal["string"], Literal["oneOf"], Literal["allOfInAnyOrder"]], + ): + self.factor_ids = factor_ids + self.type = type diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index b3b181916..2c75f6e52 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -14,23 +14,288 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Optional, Union - +from typing import Dict, Any, Union, List +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.asyncio import get_session_information +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.multifactorauth.types import FactorIds +from supertokens_python.recipe.multifactorauth.recipe import ( + MultiFactorAuthRecipe as Recipe, +) +from supertokens_python.recipe.multifactorauth.types import ( + MFAClaimValue, + MFARequirementList, +) +from supertokens_python.types import RecipeUserId +import math +import time +from typing_extensions import Literal +from supertokens_python.utils import log_debug_message +from ..multitenancy.interfaces import GetTenantOkResult if TYPE_CHECKING: - from .types import OverrideConfig, InputOverrideConfig, MultiFactorAuthConfig + from .types import OverrideConfig, MultiFactorAuthConfig def validate_and_normalise_user_input( first_factors: Optional[List[str]], - override: Union[InputOverrideConfig, None] = None, + override: Union[OverrideConfig, None] = None, ) -> MultiFactorAuthConfig: if first_factors is not None and len(first_factors) == 0: raise ValueError("'first_factors' can be either None or a non-empty list") if override is None: - override = InputOverrideConfig() + override = OverrideConfig() return MultiFactorAuthConfig( first_factors=first_factors, - override=OverrideConfig(functions=override.functions, apis=override.apis), + override=override, + ) + + +class UpdateAndGetMFARelatedInfoInSessionResult: + def __init__( + self, + completed_factors: Dict[str, int], + mfa_requirements_for_auth: MFARequirementList, + is_mfa_requirements_for_auth_satisfied: bool, + ): + self.completed_factors = completed_factors + self.mfa_requirements_for_auth = mfa_requirements_for_auth + self.is_mfa_requirements_for_auth_satisfied = ( + is_mfa_requirements_for_auth_satisfied + ) + + +async def update_and_get_mfa_related_info_in_session( + input_session_recipe_user_id: Optional[RecipeUserId], + input_tenant_id: Optional[str], + input_access_token_payload: Optional[Dict[str, Any]], + input_session: Optional[SessionContainer], + input_updated_factor_id: Optional[str], + user_context: Dict[str, Any], +) -> UpdateAndGetMFARelatedInfoInSessionResult: + session_recipe_user_id: RecipeUserId + tenant_id: str + access_token_payload: Dict[str, Any] + session_handle: str + + if input_session is not None: + session_recipe_user_id = input_session.get_recipe_user_id(user_context) + tenant_id = input_session.get_tenant_id(user_context) + access_token_payload = input_session.get_access_token_payload(user_context) + session_handle = input_session.get_handle(user_context) + else: + assert input_session_recipe_user_id is not None + assert input_tenant_id is not None + assert input_access_token_payload is not None + session_recipe_user_id = input_session_recipe_user_id + tenant_id = input_tenant_id + access_token_payload = input_access_token_payload + session_handle = access_token_payload["sessionHandle"] + + updated_claim_val = False + mfa_claim_value = MultiFactorAuthClaim.get_value_from_payload(access_token_payload) + + if input_updated_factor_id is not None: + if mfa_claim_value is None: + updated_claim_val = True + mfa_claim_value = MFAClaimValue( + c={input_updated_factor_id: math.floor(time.time())}, + v=True, # updated later in the function + ) + else: + updated_claim_val = True + mfa_claim_value.c[input_updated_factor_id] = math.floor(time.time()) + + if mfa_claim_value is None: + session_user = ( + await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + session_recipe_user_id.get_as_string(), user_context + ) + ) + if session_user is None: + raise UnauthorisedError("Session user not found") + + session_info = await get_session_information(session_handle, user_context) + if session_info is None: + raise UnauthorisedError("Session not found") + + first_factor_time = session_info.time_created + computed_first_factor_id_for_session = None + + for login_method in session_user.login_methods: + if ( + login_method.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ): + if login_method.recipe_id == "emailpassword": + valid_res = await is_valid_first_factor( + tenant_id, FactorIds.EMAILPASSWORD, user_context + ) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise UnauthorisedError("Tenant not found") + elif valid_res == "OK": + computed_first_factor_id_for_session = FactorIds.EMAILPASSWORD + break + elif login_method.recipe_id == "thirdparty": + valid_res = await is_valid_first_factor( + tenant_id, FactorIds.THIRDPARTY, user_context + ) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise UnauthorisedError("Tenant not found") + elif valid_res == "OK": + computed_first_factor_id_for_session = FactorIds.THIRDPARTY + break + else: + factors_to_check: List[str] = [] + if login_method.email is not None: + factors_to_check.extend( + [FactorIds.LINK_EMAIL, FactorIds.OTP_EMAIL] + ) + if login_method.phone_number is not None: + factors_to_check.extend( + [FactorIds.LINK_PHONE, FactorIds.OTP_PHONE] + ) + + for factor_id in factors_to_check: + valid_res = await is_valid_first_factor( + tenant_id, factor_id, user_context + ) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise UnauthorisedError("Tenant not found") + elif valid_res == "OK": + computed_first_factor_id_for_session = factor_id + break + + if computed_first_factor_id_for_session is not None: + break + + if computed_first_factor_id_for_session is None: + raise UnauthorisedError("Incorrect login method used") + + updated_claim_val = True + mfa_claim_value = MFAClaimValue( + c={computed_first_factor_id_for_session: first_factor_time}, + v=True, # updated later in this function + ) + + completed_factors = mfa_claim_value.c + + async def user_getter(): + resp = await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + session_recipe_user_id.get_as_string(), user_context + ) + if resp is None: + raise UnauthorisedError("Session user not found") + return resp + + async def get_required_secondary_factors_for_tenant( + tenant_id: str, user_context: Dict[str, Any] + ) -> List[str]: + tenant_info = await get_tenant(tenant_id, user_context) + if tenant_info is None: + raise UnauthorisedError("Tenant not found") + return ( + tenant_info.required_secondary_factors + if tenant_info.required_secondary_factors is not None + else [] + ) + + mfa_requirements_for_auth = await Recipe.get_instance_or_throw_error().recipe_implementation.get_mfa_requirements_for_auth( + tenant_id=tenant_id, + access_token_payload=access_token_payload, + user=user_getter, + factors_set_up_for_user=lambda: Recipe.get_instance_or_throw_error().recipe_implementation.get_factors_setup_for_user( + user=(await user_getter()), user_context=user_context + ), + required_secondary_factors_for_user=lambda: Recipe.get_instance_or_throw_error().recipe_implementation.get_required_secondary_factors_for_user( + user_id=(await user_getter()).id, user_context=user_context + ), + required_secondary_factors_for_tenant=lambda: get_required_secondary_factors_for_tenant( + tenant_id, user_context + ), + completed_factors=completed_factors, + user_context=user_context, ) + + are_auth_reqs_complete = ( + len( + MultiFactorAuthClaim.get_next_set_of_unsatisfied_factors( + completed_factors, mfa_requirements_for_auth + ).factor_ids + ) + == 0 + ) + + if mfa_claim_value.v != are_auth_reqs_complete: + updated_claim_val = True + mfa_claim_value.v = are_auth_reqs_complete + + if input_session is not None and updated_claim_val: + await input_session.set_claim_value( + MultiFactorAuthClaim, mfa_claim_value, user_context + ) + + return UpdateAndGetMFARelatedInfoInSessionResult( + completed_factors=completed_factors, + mfa_requirements_for_auth=mfa_requirements_for_auth, + is_mfa_requirements_for_auth_satisfied=mfa_claim_value.v, + ) + + +async def is_valid_first_factor( + tenant_id: str, factor_id: str, user_context: Dict[str, Any] +) -> Literal["OK", "INVALID_FIRST_FACTOR_ERROR", "TENANT_NOT_FOUND_ERROR"]: + tenant_info = await get_tenant(tenant_id=tenant_id, user_context=user_context) + if tenant_info is None: + return "TENANT_NOT_FOUND_ERROR" + + tenant_config = tenant_info + + mt_recipe = MultitenancyRecipe.get_instance() + + first_factors_from_mfa = mt_recipe.static_first_factors + + log_debug_message( + f"is_valid_first_factor got {', '.join(tenant_config.first_factors) if tenant_config.first_factors else None} from tenant config" + ) + log_debug_message(f"is_valid_first_factor got {first_factors_from_mfa} from MFA") + + configured_first_factors: Union[List[str], None] = ( + tenant_config.first_factors or first_factors_from_mfa + ) + + if configured_first_factors is None: + configured_first_factors = mt_recipe.all_available_first_factors + + if is_factor_configured_for_tenant( + tenant_config=tenant_config, + all_available_first_factors=mt_recipe.all_available_first_factors, + first_factors=configured_first_factors, + factor_id=factor_id, + ): + return "OK" + + return "INVALID_FIRST_FACTOR_ERROR" + + +def is_factor_configured_for_tenant( + tenant_config: GetTenantOkResult, + all_available_first_factors: List[str], + first_factors: List[str], + factor_id: str, +) -> bool: + configured_first_factors = [ + f + for f in first_factors + if f in all_available_first_factors or f not in FactorIds.__dict__.values() + ] + + return factor_id in configured_first_factors diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 6bf03c9c8..cccc7ca1b 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -102,11 +102,15 @@ def __init__( passwordless: PasswordlessConfig, third_party: ThirdPartyConfig, core_config: Dict[str, Any], + first_factors: Optional[List[str]] = None, + required_secondary_factors: Optional[List[str]] = None, ): self.emailpassword = emailpassword self.passwordless = passwordless self.third_party = third_party self.core_config = core_config + self.first_factors = first_factors + self.required_secondary_factors = required_secondary_factors class GetTenantOkResult(TenantConfigResponse): diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 9cc7aede8..5cfea3ce2 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -92,6 +92,8 @@ def __init__( ) RecipeModule.get_tenant_id = recipe_implementation.get_tenant_id + self.static_first_factors: Optional[List[str]] = None + self.all_available_first_factors: List[str] = [] def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, MultitenancyError) diff --git a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py index 95781b406..37af98f51 100644 --- a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py @@ -31,7 +31,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, RecipeUserId, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[bool]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py index e70d2663b..0ce0a2ceb 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py @@ -132,7 +132,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, RecipeUserId, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[Primitive]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 0f9b6f19c..389455841 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -629,7 +629,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, RecipeUserId, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[_T]], ], ) -> None: @@ -676,13 +676,16 @@ async def build( user_id: str, recipe_user_id: RecipeUserId, tenant_id: str, + current_payload: Dict[str, Any], user_context: Optional[Dict[str, Any]] = None, ) -> JSONObject: if user_context is None: user_context = {} value = await resolve( - self.fetch_value(user_id, recipe_user_id, tenant_id, user_context) + self.fetch_value( + user_id, recipe_user_id, tenant_id, current_payload, user_context + ) ) if value is None: diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index 225622c75..9a5fd447d 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -132,6 +132,7 @@ async def validate_claims( user_id, recipe_user_id, access_token_payload.get("tId", DEFAULT_TENANT_ID), + access_token_payload, user_context, ) ) diff --git a/tests/sessions/claims/test_assert_claims.py b/tests/sessions/claims/test_assert_claims.py index cb6942b7d..c93718f30 100644 --- a/tests/sessions/claims/test_assert_claims.py +++ b/tests/sessions/claims/test_assert_claims.py @@ -123,7 +123,9 @@ async def validate( def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]): return False - dummy_claim = PrimitiveClaim("st-claim", lambda _, __, ___, ____: "Hello world") + dummy_claim = PrimitiveClaim( + "st-claim", lambda _, __, ___, ____, _____: "Hello world" + ) dummy_claim_validator = DummyClaimValidator(dummy_claim) diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index 47bd35b06..91c3c5ee5 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -121,7 +121,9 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) assert claim.get_last_refetch_time(payload) == timestamp @@ -143,7 +145,9 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.includes(excluded_item).validate(payload, {}) assert res.is_valid is False @@ -156,7 +160,9 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.includes(included_item).validate(payload, {}) assert res.is_valid is True @@ -164,7 +170,9 @@ async def test_validator_should_validate_matching_payload(): async def test_should_not_validate_old_values(patch_get_timestamp_ms: MagicMock): claim = claim_with_inf_max_age - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -182,7 +190,9 @@ async def test_should_validate_old_values_if_max_age_is_none_and_default_is_inf( patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -200,7 +210,9 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = claim_with_inf_max_age - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) assert ( await resolve( claim.validators.includes(excluded_item, 600).should_refetch(payload, {}) @@ -213,7 +225,9 @@ async def test_validator_should_refetch_if_value_is_old( patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -230,7 +244,9 @@ async def test_validator_should_not_refetch_if_max_age_is_none_and_default_is_in patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -247,7 +263,9 @@ async def test_validator_should_validate_values_with_default_max_age( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 10 MINS: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -260,7 +278,9 @@ async def test_validator_should_not_refetch_if_max_age_overrides_to_inf( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -293,7 +313,9 @@ async def test_validator_excludes_should_not_validate_empty_payload(): async def test_validator_excludes_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.excludes(included_item).validate(payload, {}) assert res.is_valid is False @@ -306,7 +328,9 @@ async def test_validator_excludes_should_not_validate_mismatching_payload(): async def test_validator_excludes_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.excludes(excluded_item).validate(payload, {}) assert res.is_valid is True @@ -329,7 +353,9 @@ async def test_validator_includes_all_should_not_validate_empty_payload(): async def test_validator_includes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.includes_all(excluded_item).validate(payload, {}) assert res.is_valid is False @@ -342,7 +368,9 @@ async def test_validator_includes_all_should_not_validate_mismatching_payload(): async def test_validator_includes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.includes_all(included_item).validate(payload, {}) assert res.is_valid is True @@ -365,7 +393,9 @@ async def test_validator_excludes_all_should_not_validate_empty_payload(): async def test_validator_excludes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.excludes_all(included_item).validate(payload, {}) assert res.is_valid is False @@ -378,7 +408,9 @@ async def test_validator_excludes_all_should_not_validate_mismatching_payload(): async def test_validator_excludes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.excludes_all(excluded_item).validate(payload, {}) assert res.is_valid is True @@ -387,8 +419,10 @@ async def test_validator_excludes_all_should_validate_matching_payload(): async def test_validator_should_not_validate_older_values_with_5min_default_max_age( patch_get_timestamp_ms: MagicMock, ): - claim = PrimitiveArrayClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + claim = PrimitiveArrayClaim("key", sync_fetch_value, 3000) # 5 mins + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 10 MINS: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 515716e06..315e2b153 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -90,7 +90,9 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) assert claim.get_last_refetch_time(payload) == timestamp @@ -112,7 +114,9 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.has_value(val2).validate(payload, {}) assert res.is_valid is False @@ -125,7 +129,9 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.has_value(val).validate(payload, {}) assert res.is_valid is True @@ -133,7 +139,9 @@ async def test_validator_should_validate_matching_payload(): async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -151,7 +159,9 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) assert ( await resolve(claim.validators.has_value(val2).should_refetch(payload, {})) is False @@ -175,7 +185,9 @@ async def test_should_not_validate_empty_payload(): async def test_has_fresh_value_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.has_value(val2, 600).validate(payload, {}) assert res.is_valid is False assert res.reason == { @@ -187,7 +199,9 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload(): async def test_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) res = await claim.validators.has_value(val, 600).validate(payload, {}) assert res.is_valid is True @@ -197,7 +211,9 @@ async def test_should_not_validate_old_values_as_well( ): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -214,14 +230,14 @@ async def test_should_refetch_if_value_is_not_set(): async def test_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", RecipeUserId("userId"), "public") + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}) assert claim.validators.has_value(val2, 600).should_refetch(payload, {}) is False async def test_should_refetch_if_value_is_old(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", RecipeUserId("userId"), DEFAULT_TENANT_ID) + payload = await claim.build("userId", RecipeUserId("userId"), DEFAULT_TENANT_ID, {}) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -233,7 +249,9 @@ async def test_should_not_validate_old_values_as_well_with_default_max_age_provi patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -251,7 +269,7 @@ async def test_should_refetch_if_value_is_old_with_default_max_age_provided( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("userId", RecipeUserId("userId"), "public") + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore From c746310b35ae565c3207e8834e342956201336bb Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 16 Aug 2024 19:50:41 +0530 Subject: [PATCH 018/126] adds recipe impl for mfa recipe --- .../recipe/multifactorauth/interfaces.py | 2 +- .../multi_factor_auth_claim.py | 2 - .../recipe/multifactorauth/recipe.py | 139 ++++++++++- .../multifactorauth/recipe_implementation.py | 227 ++++++++++++++++++ .../recipe/multifactorauth/types.py | 82 ++++++- .../recipe/multifactorauth/utils.py | 10 +- .../recipe/multitenancy/interfaces.py | 4 + 7 files changed, 453 insertions(+), 13 deletions(-) create mode 100644 supertokens_python/recipe/multifactorauth/recipe_implementation.py diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 47964e46b..146c0fde5 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -38,7 +38,7 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( self, session: SessionContainer, factor_id: str, - mfa_requirements_for_auth: MFARequirementList, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], factors_set_up_for_user: Callable[[], Awaitable[List[str]]], user_context: Dict[str, Any], ) -> None: diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py index 625b6c3cc..4545f3c63 100644 --- a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -168,8 +168,6 @@ async def fetch_value( user_context: Dict[str, Any], ) -> MFAClaimValue: mfa_info = await update_and_get_mfa_related_info_in_session( - input_session=None, - input_updated_factor_id=None, input_session_recipe_user_id=recipe_user_id, input_tenant_id=tenant_id, input_access_token_payload=current_payload, diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index f0ecadd99..367af1ce5 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -15,14 +15,33 @@ from os import environ from typing import Any, Dict, Optional, List, Union +from typing_extensions import Literal from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.recipe.multifactorauth.interfaces import RecipeInterface +from supertokens_python.post_init_callbacks import PostSTInitCallbacks +from supertokens_python.querier import Querier +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo -from .types import OverrideConfig +from supertokens_python.types import AccountLinkingUser, RecipeUserId +from .types import ( + OverrideConfig, + GetFactorsSetupForUserFromOtherRecipesFunc, + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetEmailsForFactorFromOtherRecipesFunc, + GetPhoneNumbersForFactorsFromOtherRecipesFunc, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, +) +from .utils import validate_and_normalise_user_input +from .recipe_implementation import RecipeImplementation class MultiFactorAuthRecipe(RecipeModule): @@ -37,7 +56,43 @@ def __init__( override: Union[OverrideConfig, None] = None, ): super().__init__(recipe_id, app_info) - self.recipe_implementation: RecipeInterface + self.get_factors_setup_for_user_from_other_recipes_funcs: List[ + GetFactorsSetupForUserFromOtherRecipesFunc + ] = [] + self.get_all_available_secondary_factor_ids_from_other_recipes_funcs: List[ + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc + ] = [] + self.get_emails_for_factor_from_other_recipes_funcs: List[ + GetEmailsForFactorFromOtherRecipesFunc + ] = [] + self.get_phone_numbers_for_factor_from_other_recipes_funcs: List[ + GetPhoneNumbersForFactorsFromOtherRecipesFunc + ] = [] + self.is_get_mfa_requirements_for_auth_overridden: bool = False + + self.config = validate_and_normalise_user_input( + first_factors, + override, + ) + + recipe_implementation = RecipeImplementation( + Querier.get_instance(recipe_id), self + ) + self.recipe_implementation = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) + + def callback(): + mt_recipe = MultitenancyRecipe.get_instance() + mt_recipe.static_first_factors = self.config.first_factors + + SessionRecipe.get_instance().add_claim_validator_from_other_recipe( + MultiFactorAuthClaim.validators.has_completed_mfa_requirements_for_auth() + ) + + PostSTInitCallbacks.add_post_init_callback(callback) def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return False @@ -104,3 +159,81 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") MultiFactorAuthRecipe.__instance = None + + def add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + self, func: GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc + ): + self.get_all_available_secondary_factor_ids_from_other_recipes_funcs.append( + func + ) + + async def get_all_available_secondary_factor_ids( + self, tenant_config: TenantConfig + ) -> List[str]: + factor_ids: List[str] = [] + for ( + func + ) in self.get_all_available_secondary_factor_ids_from_other_recipes_funcs: + factor_ids_res = await func.func(tenant_config) + for factor_id in factor_ids_res: + if factor_id not in factor_ids: + factor_ids.append(factor_id) + return factor_ids + + def add_func_to_get_factors_setup_for_user_from_other_recipes( + self, func: GetFactorsSetupForUserFromOtherRecipesFunc + ): + self.get_factors_setup_for_user_from_other_recipes_funcs.append(func) + + def add_func_to_get_emails_for_factor_from_other_recipes( + self, func: GetEmailsForFactorFromOtherRecipesFunc + ): + self.get_emails_for_factor_from_other_recipes_funcs.append(func) + + async def get_emails_for_factors( + self, user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + ) -> Union[ + Dict[ + Literal["status", "factorIdToEmailsMap"], + Union[Literal["OK"], Dict[str, List[str]]], + ], + Dict[Literal["status"], Literal["UNKNOWN_SESSION_RECIPE_USER_ID"]], + ]: + + factorIdToEmailsMap: Dict[str, List[str]] = {} + + for func in self.get_emails_for_factor_from_other_recipes_funcs: + func_result = await func.func(user, session_recipe_user_id) + if isinstance( + func_result, GetEmailsForFactorUnknownSessionRecipeUserIdResult + ): + return {"status": "UNKNOWN_SESSION_RECIPE_USER_ID"} + factorIdToEmailsMap.update(func_result.factor_id_to_emails_map) + + return {"status": "OK", "factorIdToEmailsMap": factorIdToEmailsMap} + + def add_func_to_get_phone_numbers_for_factors_from_other_recipes( + self, func: GetPhoneNumbersForFactorsFromOtherRecipesFunc + ): + self.get_phone_numbers_for_factor_from_other_recipes_funcs.append(func) + + async def get_phone_numbers_for_factors( + self, user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + ) -> Union[ + Dict[ + Literal["status", "factorIdToPhoneNumberMap"], + Union[Literal["OK"], Dict[str, List[str]]], + ], + Dict[Literal["status"], Literal["UNKNOWN_SESSION_RECIPE_USER_ID"]], + ]: + factorIdToPhoneNumberMap: Dict[str, List[str]] = {} + + for func in self.get_phone_numbers_for_factor_from_other_recipes_funcs: + func_result = await func.func(user, session_recipe_user_id) + if isinstance( + func_result, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult + ): + return {"status": "UNKNOWN_SESSION_RECIPE_USER_ID"} + factorIdToPhoneNumberMap.update(func_result.factor_id_to_phone_number_map) + + return {"status": "OK", "factorIdToPhoneNumberMap": factorIdToPhoneNumberMap} diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py new file mode 100644 index 000000000..89797a17e --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Set, Callable, List + +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, + MultiFactorAuthClaimClass, +) +from supertokens_python.recipe.session.interfaces import ( + ClaimValidationResult, + JSONObject, + SessionClaimValidator, +) +from supertokens_python.recipe.usermetadata.asyncio import ( + get_user_metadata, + update_user_metadata, +) +from supertokens_python.recipe.multifactorauth.types import ( + MFAClaimValue, + MFARequirementList, +) +from supertokens_python.recipe.session import SessionContainer + +from .interfaces import RecipeInterface + +from .recipe import MultiFactorAuthRecipe +from supertokens_python.types import AccountLinkingUser +from .utils import update_and_get_mfa_related_info_in_session + + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + + +class Validator(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: MultiFactorAuthClaimClass, + mfa_requirement_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + factor_id: str, + ): + super().__init__(id_) + self.claim: MultiFactorAuthClaimClass = claim + self.factors_set_up_for_user = factors_set_up_for_user + self.factor_id = factor_id + self.mfa_requirement_for_auth = mfa_requirement_for_auth + + async def should_refetch( + self, payload: Dict[str, Any], user_context: Dict[str, Any] + ) -> bool: + return True if self.claim.get_value_from_payload(payload) is None else False + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + claim_val: MFAClaimValue | None = self.claim.get_value_from_payload(payload) + + if claim_val is None: + raise Exception( + "This should never happen, claim value not present in payload" + ) + + if claim_val.v: + # Session already satisfied auth requirements + return ClaimValidationResult(is_valid=True) + + set_of_unsatisfied_factors = self.claim.get_next_set_of_unsatisfied_factors( + claim_val.c, await self.mfa_requirement_for_auth() + ) + + factors_set_up_for_user = await self.factors_set_up_for_user() + + if any( + factor_id in factors_set_up_for_user + for factor_id in set_of_unsatisfied_factors.factor_ids + ): + return ClaimValidationResult( + is_valid=False, + reason={ + "message": "Completed factors in the session does not satisfy the MFA requirements for auth", + }, + ) + + if ( + set_of_unsatisfied_factors.factor_ids + and self.factor_id not in set_of_unsatisfied_factors.factor_ids + ): + return ClaimValidationResult( + is_valid=False, + reason={ + "message": "Not allowed to setup factor that is not in the next set of unsatisfied factors", + }, + ) + + return ClaimValidationResult(is_valid=True) + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + recipe_instance: MultiFactorAuthRecipe, + ): + super().__init__() + self.querier = querier + self.recipe_instance = recipe_instance + + async def get_factors_setup_for_user( + self, user: AccountLinkingUser, user_context: Dict[str, Any] + ) -> List[str]: + factor_ids: List[str] = [] + for ( + func + ) in self.recipe_instance.get_factors_setup_for_user_from_other_recipes_funcs: + result = await func.func(user, user_context) + for factor_id in result: + if factor_id not in factor_ids: + factor_ids.append(factor_id) + return factor_ids + + async def get_mfa_requirements_for_auth( + self, + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[AccountLinkingUser]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + all_factors: Set[str] = set() + for factor in await required_secondary_factors_for_user(): + all_factors.add(factor) + for factor in await required_secondary_factors_for_tenant(): + all_factors.add(factor) + return MFARequirementList({"oneOf": list(all_factors)}) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + self, + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + await session.assert_claims( + [ + Validator( + id_=MultiFactorAuthClaim.key, + claim=MultiFactorAuthClaim, + mfa_requirement_for_auth=mfa_requirements_for_auth, + factors_set_up_for_user=factors_set_up_for_user, + factor_id=factor_id, + ) + ], + user_context, + ) + + async def mark_factor_as_complete_in_session( + self, session: SessionContainer, factor_id: str, user_context: Dict[str, Any] + ): + await update_and_get_mfa_related_info_in_session( + input_session=session, + input_updated_factor_id=factor_id, + user_context=user_context, + ) + + async def get_required_secondary_factors_for_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> List[str]: + metadata = await get_user_metadata(user_id, user_context) + result: List[str] = metadata.metadata.get("_supertokens", {}).get( + "requiredSecondaryFactors", [] + ) + return result + + async def add_to_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ): + metadata = await get_user_metadata(user_id, user_context) + factor_ids: List[str] = metadata.metadata.get("_supertokens", {}).get( + "requiredSecondaryFactors", [] + ) + if factor_id not in factor_ids: + factor_ids.append(factor_id) + metadata_update = { + **metadata.metadata, + "_supertokens": { + **metadata.metadata.get("_supertokens", {}), + "requiredSecondaryFactors": factor_ids, + }, + } + await update_user_metadata(user_id, metadata_update, user_context) + + async def remove_from_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ): + metadata = await get_user_metadata(user_id, user_context) + factor_ids: List[str] = metadata.metadata.get("_supertokens", {}).get( + "requiredSecondaryFactors", [] + ) + if factor_id in factor_ids: + factor_ids = [id for id in factor_ids if id != factor_id] + metadata_update = { + **metadata.metadata, + "_supertokens": { + **metadata.metadata.get("_supertokens", {}), + "requiredSecondaryFactors": factor_ids, + }, + } + await update_user_metadata(user_id, metadata_update, user_context) diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 11b1672f6..0a82e61c2 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -12,13 +12,21 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Dict, Any, Union, List, Optional, Callable +from typing import Awaitable, Dict, Any, Union, List, Optional, Callable + +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from .interfaces import RecipeInterface, APIInterface from typing_extensions import Literal +from supertokens_python.types import AccountLinkingUser, RecipeUserId class MFARequirementList(List[Union[Dict[str, List[str]], str]]): - def __init__(self, *args: Union[str, Dict[str, List[str]]]): + def __init__( + self, + *args: Union[ + str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]] + ] + ): super().__init__() for arg in args: if isinstance(arg, str): @@ -79,3 +87,73 @@ def __init__( ): self.factor_ids = factor_ids self.type = type + + +class GetFactorsSetupForUserFromOtherRecipesFunc: + def __init__( + self, + func: Callable[[AccountLinkingUser, Dict[str, Any]], Awaitable[List[str]]], + ): + self.func = func + + +class GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc: + def __init__( + self, + func: Callable[[TenantConfig], Awaitable[List[str]]], + ): + self.func = func + + +class GetEmailsForFactorOkResult: + status: Literal["OK"] = "OK" + + def __init__(self, factor_id_to_emails_map: Dict[str, List[str]]): + self.factor_id_to_emails_map = factor_id_to_emails_map + + +class GetEmailsForFactorUnknownSessionRecipeUserIdResult: + status: Literal["UNKNOWN_SESSION_RECIPE_USER_ID"] = "UNKNOWN_SESSION_RECIPE_USER_ID" + + +class GetEmailsForFactorFromOtherRecipesFunc: + def __init__( + self, + func: Callable[ + [AccountLinkingUser, RecipeUserId], + Awaitable[ + Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ] + ], + ], + ): + self.func = func + + +class GetPhoneNumbersForFactorsOkResult: + status: Literal["OK"] = "OK" + + def __init__(self, factor_id_to_phone_number_map: Dict[str, List[str]]): + self.factor_id_to_phone_number_map = factor_id_to_phone_number_map + + +class GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult: + status: Literal["UNKNOWN_SESSION_RECIPE_USER_ID"] = "UNKNOWN_SESSION_RECIPE_USER_ID" + + +class GetPhoneNumbersForFactorsFromOtherRecipesFunc: + def __init__( + self, + func: Callable[ + [AccountLinkingUser, RecipeUserId], + Awaitable[ + Union[ + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + ] + ], + ], + ): + self.func = func diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index 2c75f6e52..a301bd85e 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -74,12 +74,12 @@ def __init__( async def update_and_get_mfa_related_info_in_session( - input_session_recipe_user_id: Optional[RecipeUserId], - input_tenant_id: Optional[str], - input_access_token_payload: Optional[Dict[str, Any]], - input_session: Optional[SessionContainer], - input_updated_factor_id: Optional[str], user_context: Dict[str, Any], + input_session_recipe_user_id: Optional[RecipeUserId] = None, + input_tenant_id: Optional[str] = None, + input_access_token_payload: Optional[Dict[str, Any]] = None, + input_session: Optional[SessionContainer] = None, + input_updated_factor_id: Optional[str] = None, ) -> UpdateAndGetMFARelatedInfoInSessionResult: session_recipe_user_id: RecipeUserId tenant_id: str diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index cccc7ca1b..1e02d3b34 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -34,11 +34,15 @@ def __init__( passwordless_enabled: Union[bool, None] = None, third_party_enabled: Union[bool, None] = None, core_config: Union[Dict[str, Any], None] = None, + first_factors: Optional[List[str]] = None, + required_secondary_factors: Optional[List[str]] = None, ): self.email_password_enabled = email_password_enabled self.passwordless_enabled = passwordless_enabled self.third_party_enabled = third_party_enabled self.core_config = core_config + self.first_factors = first_factors + self.required_secondary_factors = required_secondary_factors def to_json(self) -> Dict[str, Any]: res: Dict[str, Any] = {} From 57ce38b7693055187ed32ae09dc5f3545c49f0fc Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 16 Aug 2024 20:58:06 +0530 Subject: [PATCH 019/126] adds mfa recipe --- .../recipe/dashboard/api/list_tenants.py | 7 +- .../recipe/dashboard/interfaces.py | 34 +--- .../recipe/multifactorauth/api/__init__.py | 15 ++ .../multifactorauth/api/implementation.py | 153 +++++++++++++++++ .../api/resync_session_and_fetch_mfa_info.py | 52 ++++++ .../multifactorauth/asyncio/__init__.py | 157 ++++++++++++++++++ .../recipe/multifactorauth/constants.py | 14 ++ .../recipe/multifactorauth/interfaces.py | 26 ++- .../recipe/multifactorauth/recipe.py | 66 ++++++-- .../recipe/multifactorauth/syncio/__init__.py | 130 +++++++++++++++ .../recipe/multifactorauth/utils.py | 3 - .../recipe/multitenancy/api/implementation.py | 8 +- .../recipe/multitenancy/asyncio/__init__.py | 3 +- .../recipe/multitenancy/interfaces.py | 104 ++---------- .../multitenancy/recipe_implementation.py | 41 ++--- .../thirdparty/recipe_implementation.py | 2 +- tests/multitenancy/test_tenants_crud.py | 40 ++--- tests/test-server/multitenancy.py | 60 +++---- 18 files changed, 686 insertions(+), 229 deletions(-) create mode 100644 supertokens_python/recipe/multifactorauth/api/__init__.py create mode 100644 supertokens_python/recipe/multifactorauth/api/implementation.py create mode 100644 supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py create mode 100644 supertokens_python/recipe/multifactorauth/asyncio/__init__.py create mode 100644 supertokens_python/recipe/multifactorauth/constants.py create mode 100644 supertokens_python/recipe/multifactorauth/syncio/__init__.py diff --git a/supertokens_python/recipe/dashboard/api/list_tenants.py b/supertokens_python/recipe/dashboard/api/list_tenants.py index 6d519c101..810ee4c6b 100644 --- a/supertokens_python/recipe/dashboard/api/list_tenants.py +++ b/supertokens_python/recipe/dashboard/api/list_tenants.py @@ -42,12 +42,7 @@ async def handle_list_tenants_api( final_tenants: List[DashboardListTenantItem] = [] for current_tenant in tenants.tenants: - dashboard_tenant = DashboardListTenantItem( - tenant_id=current_tenant.tenant_id, - emailpassword=current_tenant.emailpassword, - passwordless=current_tenant.passwordless, - third_party=current_tenant.third_party, - ) + dashboard_tenant = DashboardListTenantItem(current_tenant) final_tenants.append(dashboard_tenant) return DashboardListTenantsGetResponse(final_tenants) diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index dbf3727ce..2e93b40f5 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.types import AccountLinkingUser @@ -27,12 +28,6 @@ from supertokens_python.recipe.session.interfaces import SessionInformationResult from supertokens_python.framework import BaseRequest, BaseResponse - from supertokens_python.recipe.multitenancy.interfaces import ( - EmailPasswordConfig, - PasswordlessConfig, - ThirdPartyConfig, - ) - class SessionInfo: def __init__(self, info: SessionInformationResult) -> None: @@ -109,28 +104,17 @@ def to_json(self) -> Dict[str, Any]: class DashboardListTenantItem: - def __init__( - self, - tenant_id: str, - emailpassword: EmailPasswordConfig, - passwordless: PasswordlessConfig, - third_party: ThirdPartyConfig, - ): - self.tenant_id = tenant_id - self.emailpassword = emailpassword - self.passwordless = passwordless - self.third_party = third_party + def __init__(self, tenant_config: TenantConfig): + self.tenant_config = tenant_config - def to_json(self): - res = { - "tenantId": self.tenant_id, - "emailPassword": self.emailpassword.to_json(), - "passwordless": self.passwordless.to_json(), - "thirdParty": self.third_party.to_json(), + def to_json(self) -> Dict[str, Any]: + return { + "tenantId": self.tenant_config.tenant_id, + "emailPassword": {"enabled": self.tenant_config.email_password_enabled}, + "passwordless": {"enabled": self.tenant_config.passwordless_enabled}, + "thirdParty": {"enabled": self.tenant_config.third_party_enabled}, } - return res - class DashboardListTenantsGetResponse(APIResponse): status: str = "OK" diff --git a/supertokens_python/recipe/multifactorauth/api/__init__.py b/supertokens_python/recipe/multifactorauth/api/__init__.py new file mode 100644 index 000000000..13214e8aa --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/api/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from .resync_session_and_fetch_mfa_info import handle_resync_session_and_fetch_mfa_info_api # type: ignore diff --git a/supertokens_python/recipe/multifactorauth/api/implementation.py b/supertokens_python/recipe/multifactorauth/api/implementation.py new file mode 100644 index 000000000..9f3dd774b --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/api/implementation.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.multifactorauth.utils import ( + update_and_get_mfa_related_info_in_session, +) +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from ..multi_factor_auth_claim import MultiFactorAuthClaim +from supertokens_python.asyncio import get_user +from supertokens_python.recipe.session.exceptions import ( + InvalidClaimsError, + SuperTokensSessionError, + UnauthorisedError, +) + +if TYPE_CHECKING: + from supertokens_python.recipe.multifactorauth.interfaces import ( + APIInterface, + APIOptions, + ) + +from supertokens_python.types import GeneralErrorResponse +from ..interfaces import ( + APIInterface, + APIOptions, + NextFactors, + ResyncSessionAndFetchMFAInfoPUTOkResult, +) + + +class APIImplementation(APIInterface): + async def resync_session_and_fetch_mfa_info_put( + self, + api_options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + session_user = await get_user(session.get_user_id(), user_context) + + if session_user is None: + raise UnauthorisedError( + "Session user not found", + ) + + mfa_info = await update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + factors_setup_for_user = ( + await api_options.recipe_implementation.get_factors_setup_for_user( + user=session_user, + user_context=user_context, + ) + ) + tenant_info = await get_tenant( + session.get_tenant_id(user_context), user_context + ) + if tenant_info is None: + raise UnauthorisedError( + "Tenant not found", + ) + all_available_secondary_factors = ( + await api_options.recipe_instance.get_all_available_secondary_factor_ids( + tenant_info + ) + ) + + factors_allowed_to_setup: List[str] = [] + + async def get_factors_set_up_for_user(): + return factors_setup_for_user + + async def get_mfa_requirements_for_auth(): + return mfa_info.mfa_requirements_for_auth + + for factor_id in all_available_secondary_factors: + try: + await api_options.recipe_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session=session, + factor_id=factor_id, + factors_set_up_for_user=get_factors_set_up_for_user, + mfa_requirements_for_auth=get_mfa_requirements_for_auth, + user_context=user_context, + ) + factors_allowed_to_setup.append(factor_id) + except SuperTokensSessionError as err: + if not isinstance(err, InvalidClaimsError): + raise err + + next_set_of_unsatisfied_factors = ( + MultiFactorAuthClaim.get_next_set_of_unsatisfied_factors( + mfa_info.completed_factors, mfa_info.mfa_requirements_for_auth + ) + ) + + get_emails_for_factors_result = ( + await api_options.recipe_instance.get_emails_for_factors( + session_user, session.get_recipe_user_id(user_context) + ) + ) + get_phone_numbers_for_factors_result = ( + await api_options.recipe_instance.get_phone_numbers_for_factors( + session_user, session.get_recipe_user_id(user_context) + ) + ) + if ( + get_emails_for_factors_result.status == "UNKNOWN_SESSION_RECIPE_USER_ID" + or get_phone_numbers_for_factors_result.status + == "UNKNOWN_SESSION_RECIPE_USER_ID" + ): + raise UnauthorisedError( + "User no longer associated with the session", + ) + + next_factors = [ + factor_id + for factor_id in next_set_of_unsatisfied_factors.factor_ids + if factor_id in factors_allowed_to_setup + or factor_id in factors_setup_for_user + ] + + if ( + len(next_factors) == 0 + and len(next_set_of_unsatisfied_factors.factor_ids) != 0 + ): + raise Exception( + f"The user is required to complete secondary factors they are not allowed to " + f"({', '.join(next_set_of_unsatisfied_factors.factor_ids)}), likely because of configuration issues." + ) + return ResyncSessionAndFetchMFAInfoPUTOkResult( + factors=NextFactors( + next=next_factors, + already_setup=factors_setup_for_user, + allowed_to_setup=factors_allowed_to_setup, + ), + emails=get_emails_for_factors_result.factor_id_to_emails_map, + phone_numbers=get_phone_numbers_for_factors_result.factor_id_to_phone_number_map, + ) diff --git a/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py b/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py new file mode 100644 index 000000000..1048253e6 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + +from supertokens_python.recipe.session.asyncio import get_session + + +async def handle_resync_session_and_fetch_mfa_info_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_resync_session_and_fetch_mfa_info_put is True: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + assert session is not None + + response = await api_implementation.resync_session_and_fetch_mfa_info_put( + api_options, + session, + user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py new file mode 100644 index 000000000..9bdbc5e36 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py @@ -0,0 +1,157 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, Dict, Optional, List + +from supertokens_python.recipe.session import SessionContainer + +from ..interfaces import ( + MFARequirementList, +) +from ..recipe import MultiFactorAuthRecipe +from ..utils import update_and_get_mfa_related_info_in_session +from supertokens_python.recipe.accountlinking.asyncio import get_user + + +async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + mfa_info = await update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + factors_set_up_for_user = await get_factors_setup_for_user( + session.get_user_id(), user_context + ) + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + + async def func_factors_set_up_for_user(): + return factors_set_up_for_user + + async def func_mfa_requirements_for_auth(): + return mfa_info.mfa_requirements_for_auth + + await recipe.recipe_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session=session, + factor_id=factor_id, + factors_set_up_for_user=func_factors_set_up_for_user, + mfa_requirements_for_auth=func_mfa_requirements_for_auth, + user_context=user_context, + ) + + +async def get_mfa_requirements_for_auth( + session: SessionContainer, + user_context: Optional[Dict[str, Any]] = None, +) -> MFARequirementList: + if user_context is None: + user_context = {} + + mfa_info = await update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + + return mfa_info.mfa_requirements_for_auth + + +async def mark_factor_as_complete_in_session( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + await recipe.recipe_implementation.mark_factor_as_complete_in_session( + session=session, + factor_id=factor_id, + user_context=user_context, + ) + + +async def get_factors_setup_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + user = await get_user(user_id, user_context) + if user is None: + raise Exception("Unknown user id") + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + return await recipe.recipe_implementation.get_factors_setup_for_user( + user=user, + user_context=user_context, + ) + + +async def get_required_secondary_factors_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + return await recipe.recipe_implementation.get_required_secondary_factors_for_user( + user_id=user_id, + user_context=user_context, + ) + + +async def add_to_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + await recipe.recipe_implementation.add_to_required_secondary_factors_for_user( + user_id=user_id, + factor_id=factor_id, + user_context=user_context, + ) + + +async def remove_from_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + await recipe.recipe_implementation.remove_from_required_secondary_factors_for_user( + user_id=user_id, + factor_id=factor_id, + user_context=user_context, + ) + + +init = MultiFactorAuthRecipe.init diff --git a/supertokens_python/recipe/multifactorauth/constants.py b/supertokens_python/recipe/multifactorauth/constants.py new file mode 100644 index 000000000..5f7a38023 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/constants.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +RESYNC_SESSION_AND_FETCH_MFA_INFO = "/mfa/info" diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 146c0fde5..ea8ec20b4 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -16,6 +16,7 @@ from abc import ABC, abstractmethod from typing import Dict, Any, Union, List, Callable, Awaitable +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.types import AccountLinkingUser @@ -24,7 +25,7 @@ from ...supertokens import AppInfo -from ...types import GeneralErrorResponse +from ...types import APIResponse, GeneralErrorResponse if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse @@ -101,6 +102,7 @@ def __init__( config: MultiFactorAuthConfig, recipe_implementation: RecipeInterface, app_info: AppInfo, + recipe_instance: MultiFactorAuthRecipe, ): self.request: BaseRequest = request self.response: BaseResponse = response @@ -108,6 +110,7 @@ def __init__( self.config = config self.recipe_implementation: RecipeInterface = recipe_implementation self.app_info = app_info + self.recipe_instance = recipe_instance class APIInterface: @@ -132,16 +135,31 @@ def __init__( self.already_setup = already_setup self.allowed_to_setup = allowed_to_setup + def to_json(self) -> Dict[str, Any]: + return { + "next": self.next, + "alreadySetup": self.already_setup, + "allowedToSetup": self.allowed_to_setup, + } -class ResyncSessionAndFetchMFAInfoPUTOkResult: + +class ResyncSessionAndFetchMFAInfoPUTOkResult(APIResponse): def __init__( self, factors: NextFactors, - emails: Dict[str, Union[List[str], None]], - phone_numbers: Dict[str, Union[List[str], None]], + emails: Dict[str, List[str]], + phone_numbers: Dict[str, List[str]], ): self.factors = factors self.emails = emails self.phone_numbers = phone_numbers status: str = "OK" + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "factors": self.factors.to_json(), + "emails": self.emails, + "phoneNumbers": self.phone_numbers, + } diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index 367af1ce5..c8882b07c 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -15,13 +15,18 @@ from os import environ from typing import Any, Dict, Optional, List, Union -from typing_extensions import Literal from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.querier import Querier +from supertokens_python.recipe.multifactorauth.api import ( + resync_session_and_fetch_mfa_info, +) +from supertokens_python.recipe.multifactorauth.constants import ( + RESYNC_SESSION_AND_FETCH_MFA_INFO, +) from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( MultiFactorAuthClaim, ) @@ -39,9 +44,13 @@ GetPhoneNumbersForFactorsFromOtherRecipesFunc, GetEmailsForFactorUnknownSessionRecipeUserIdResult, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + GetEmailsForFactorOkResult, + GetPhoneNumbersForFactorsOkResult, ) from .utils import validate_and_normalise_user_input from .recipe_implementation import RecipeImplementation +from .api.implementation import APIImplementation +from .interfaces import APIOptions class MultiFactorAuthRecipe(RecipeModule): @@ -84,6 +93,13 @@ def __init__( else self.config.override.functions(recipe_implementation) ) + api_implementation = APIImplementation() + self.api_implementation = ( + api_implementation + if self.config.override.apis is None + else self.config.override.apis(api_implementation) + ) + def callback(): mt_recipe = MultitenancyRecipe.get_instance() mt_recipe.static_first_factors = self.config.first_factors @@ -98,7 +114,16 @@ def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return False def get_apis_handled(self) -> List[APIHandled]: - return [] + return [ + APIHandled( + method="put", + path_without_api_base_path=NormalisedURLPath( + RESYNC_SESSION_AND_FETCH_MFA_INFO + ), + request_id=RESYNC_SESSION_AND_FETCH_MFA_INFO, + disabled=self.api_implementation.disable_resync_session_and_fetch_mfa_info_put, + ) + ] async def handle_api_request( self, @@ -110,7 +135,18 @@ async def handle_api_request( response: BaseResponse, user_context: Dict[str, Any], ): - return None + options = APIOptions( + request, + response, + self.get_recipe_id(), + self.config, + self.recipe_implementation, + self.get_app_info(), + self, + ) + return await resync_session_and_fetch_mfa_info.handle_resync_session_and_fetch_mfa_info_api( + tenant_id, self.api_implementation, options, user_context + ) async def handle_error( self, @@ -193,11 +229,8 @@ def add_func_to_get_emails_for_factor_from_other_recipes( async def get_emails_for_factors( self, user: AccountLinkingUser, session_recipe_user_id: RecipeUserId ) -> Union[ - Dict[ - Literal["status", "factorIdToEmailsMap"], - Union[Literal["OK"], Dict[str, List[str]]], - ], - Dict[Literal["status"], Literal["UNKNOWN_SESSION_RECIPE_USER_ID"]], + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, ]: factorIdToEmailsMap: Dict[str, List[str]] = {} @@ -207,10 +240,10 @@ async def get_emails_for_factors( if isinstance( func_result, GetEmailsForFactorUnknownSessionRecipeUserIdResult ): - return {"status": "UNKNOWN_SESSION_RECIPE_USER_ID"} + return GetEmailsForFactorUnknownSessionRecipeUserIdResult() factorIdToEmailsMap.update(func_result.factor_id_to_emails_map) - return {"status": "OK", "factorIdToEmailsMap": factorIdToEmailsMap} + return GetEmailsForFactorOkResult(factor_id_to_emails_map=factorIdToEmailsMap) def add_func_to_get_phone_numbers_for_factors_from_other_recipes( self, func: GetPhoneNumbersForFactorsFromOtherRecipesFunc @@ -220,11 +253,8 @@ def add_func_to_get_phone_numbers_for_factors_from_other_recipes( async def get_phone_numbers_for_factors( self, user: AccountLinkingUser, session_recipe_user_id: RecipeUserId ) -> Union[ - Dict[ - Literal["status", "factorIdToPhoneNumberMap"], - Union[Literal["OK"], Dict[str, List[str]]], - ], - Dict[Literal["status"], Literal["UNKNOWN_SESSION_RECIPE_USER_ID"]], + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, ]: factorIdToPhoneNumberMap: Dict[str, List[str]] = {} @@ -233,7 +263,9 @@ async def get_phone_numbers_for_factors( if isinstance( func_result, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult ): - return {"status": "UNKNOWN_SESSION_RECIPE_USER_ID"} + return GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult() factorIdToPhoneNumberMap.update(func_result.factor_id_to_phone_number_map) - return {"status": "OK", "factorIdToPhoneNumberMap": factorIdToPhoneNumberMap} + return GetPhoneNumbersForFactorsOkResult( + factor_id_to_phone_number_map=factorIdToPhoneNumberMap + ) diff --git a/supertokens_python/recipe/multifactorauth/syncio/__init__.py b/supertokens_python/recipe/multifactorauth/syncio/__init__.py new file mode 100644 index 000000000..c12a6ef4a --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/syncio/__init__.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, Dict, Optional, List + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.async_to_sync_wrapper import sync + +from ..interfaces import ( + MFARequirementList, +) +from ..recipe import MultiFactorAuthRecipe + + +def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error as async_func, + ) + + return sync(async_func(session, factor_id, user_context)) + + +def get_mfa_requirements_for_auth( + session: SessionContainer, + user_context: Optional[Dict[str, Any]] = None, +) -> MFARequirementList: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + get_mfa_requirements_for_auth as async_func, + ) + + return sync(async_func(session, user_context)) + + +def mark_factor_as_complete_in_session( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + mark_factor_as_complete_in_session as async_func, + ) + + return sync(async_func(session, factor_id, user_context)) + + +def get_factors_setup_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + get_factors_setup_for_user as async_func, + ) + + return sync(async_func(user_id, user_context)) + + +def get_required_secondary_factors_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + get_required_secondary_factors_for_user as async_func, + ) + + return sync(async_func(user_id, user_context)) + + +def add_to_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user as async_func, + ) + + return sync(async_func(user_id, factor_id, user_context)) + + +def remove_from_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + remove_from_required_secondary_factors_for_user as async_func, + ) + + return sync(async_func(user_id, factor_id, user_context)) + + +init = MultiFactorAuthRecipe.init diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index a301bd85e..23b4d30d4 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -37,7 +37,6 @@ import time from typing_extensions import Literal from supertokens_python.utils import log_debug_message -from ..multitenancy.interfaces import GetTenantOkResult if TYPE_CHECKING: from .types import OverrideConfig, MultiFactorAuthConfig @@ -276,7 +275,6 @@ async def is_valid_first_factor( configured_first_factors = mt_recipe.all_available_first_factors if is_factor_configured_for_tenant( - tenant_config=tenant_config, all_available_first_factors=mt_recipe.all_available_first_factors, first_factors=configured_first_factors, factor_id=factor_id, @@ -287,7 +285,6 @@ async def is_valid_first_factor( def is_factor_configured_for_tenant( - tenant_config: GetTenantOkResult, all_available_first_factors: List[str], first_factors: List[str], factor_id: str, diff --git a/supertokens_python/recipe/multitenancy/api/implementation.py b/supertokens_python/recipe/multitenancy/api/implementation.py index 77b3ecaf4..38f599352 100644 --- a/supertokens_python/recipe/multitenancy/api/implementation.py +++ b/supertokens_python/recipe/multitenancy/api/implementation.py @@ -52,7 +52,7 @@ async def login_methods_get( raise Exception("Tenant not found") provider_inputs_from_static = api_options.static_third_party_providers - provider_configs_from_core = tenant_config.third_party.providers + provider_configs_from_core = tenant_config.third_party_providers merged_providers = merge_providers_from_core_and_static( provider_configs_from_core, @@ -82,10 +82,10 @@ async def login_methods_get( return LoginMethodsGetOkResult( email_password=LoginMethodEmailPassword( - tenant_config.emailpassword.enabled + tenant_config.email_password_enabled ), - passwordless=LoginMethodPasswordless(tenant_config.passwordless.enabled), + passwordless=LoginMethodPasswordless(tenant_config.passwordless_enabled), third_party=LoginMethodThirdParty( - tenant_config.third_party.enabled, final_provider_list + tenant_config.third_party_enabled, final_provider_list ), ) diff --git a/supertokens_python/recipe/multitenancy/asyncio/__init__.py b/supertokens_python/recipe/multitenancy/asyncio/__init__.py index 1998cae3f..cfc495c25 100644 --- a/supertokens_python/recipe/multitenancy/asyncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/asyncio/__init__.py @@ -18,7 +18,6 @@ TenantConfig, CreateOrUpdateTenantOkResult, DeleteTenantOkResult, - GetTenantOkResult, ListAllTenantsOkResult, CreateOrUpdateThirdPartyConfigOkResult, DeleteThirdPartyConfigOkResult, @@ -61,7 +60,7 @@ async def delete_tenant( async def get_tenant( tenant_id: str, user_context: Optional[Dict[str, Any]] = None -) -> Optional[GetTenantOkResult]: +) -> Optional[TenantConfig]: if user_context is None: user_context = {} recipe = MultitenancyRecipe.get_instance() diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 1e02d3b34..04873ef04 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -28,32 +28,34 @@ class TenantConfig: + # pylint: disable=dangerous-default-value def __init__( self, - email_password_enabled: Union[bool, None] = None, - passwordless_enabled: Union[bool, None] = None, - third_party_enabled: Union[bool, None] = None, - core_config: Union[Dict[str, Any], None] = None, + tenant_id: str = "", + third_party_providers: List[ProviderConfig] = [], + email_password_enabled: bool = False, + passwordless_enabled: bool = False, + third_party_enabled: bool = False, + core_config: Dict[str, Any] = {}, first_factors: Optional[List[str]] = None, required_secondary_factors: Optional[List[str]] = None, ): + self.tenant_id = tenant_id self.email_password_enabled = email_password_enabled self.passwordless_enabled = passwordless_enabled self.third_party_enabled = third_party_enabled self.core_config = core_config self.first_factors = first_factors self.required_secondary_factors = required_secondary_factors + self.third_party_providers = third_party_providers def to_json(self) -> Dict[str, Any]: res: Dict[str, Any] = {} - if self.email_password_enabled is not None: - res["emailPasswordEnabled"] = self.email_password_enabled - if self.passwordless_enabled is not None: - res["passwordlessEnabled"] = self.passwordless_enabled - if self.third_party_enabled is not None: - res["thirdPartyEnabled"] = self.third_party_enabled - if self.core_config is not None: - res["coreConfig"] = self.core_config + res["tenantId"] = self.tenant_id + res["emailPasswordEnabled"] = self.email_password_enabled + res["passwordlessEnabled"] = self.passwordless_enabled + res["thirdPartyEnabled"] = self.third_party_enabled + res["coreConfig"] = self.core_config return res @@ -71,84 +73,10 @@ def __init__(self, did_exist: bool): self.did_exist = did_exist -class EmailPasswordConfig: - def __init__(self, enabled: bool): - self.enabled = enabled - - def to_json(self): - return {"enabled": self.enabled} - - -class PasswordlessConfig: - def __init__(self, enabled: bool): - self.enabled = enabled - - def to_json(self): - return {"enabled": self.enabled} - - -class ThirdPartyConfig: - def __init__(self, enabled: bool, providers: List[ProviderConfig]): - self.enabled = enabled - self.providers = providers - - def to_json(self): - return { - "enabled": self.enabled, - "providers": [provider.to_json() for provider in self.providers], - } - - -class TenantConfigResponse: - def __init__( - self, - emailpassword: EmailPasswordConfig, - passwordless: PasswordlessConfig, - third_party: ThirdPartyConfig, - core_config: Dict[str, Any], - first_factors: Optional[List[str]] = None, - required_secondary_factors: Optional[List[str]] = None, - ): - self.emailpassword = emailpassword - self.passwordless = passwordless - self.third_party = third_party - self.core_config = core_config - self.first_factors = first_factors - self.required_secondary_factors = required_secondary_factors - - -class GetTenantOkResult(TenantConfigResponse): - status = "OK" - - -class ListAllTenantsItem(TenantConfigResponse): - def __init__( - self, - tenant_id: str, - emailpassword: EmailPasswordConfig, - passwordless: PasswordlessConfig, - third_party: ThirdPartyConfig, - core_config: Dict[str, Any], - ): - super().__init__(emailpassword, passwordless, third_party, core_config) - self.tenant_id = tenant_id - - def to_json(self): - res = { - "tenantId": self.tenant_id, - "emailPassword": self.emailpassword.to_json(), - "passwordless": self.passwordless.to_json(), - "thirdParty": self.third_party.to_json(), - "coreConfig": self.core_config, - } - - return res - - class ListAllTenantsOkResult: status = "OK" - def __init__(self, tenants: List[ListAllTenantsItem]): + def __init__(self, tenants: List[TenantConfig]): self.tenants = tenants @@ -224,7 +152,7 @@ async def delete_tenant( @abstractmethod async def get_tenant( self, tenant_id: str, user_context: Dict[str, Any] - ) -> Optional[GetTenantOkResult]: + ) -> Optional[TenantConfig]: pass @abstractmethod diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 30a1105b4..1e2ef5417 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -28,15 +28,9 @@ TenantConfig, CreateOrUpdateTenantOkResult, DeleteTenantOkResult, - TenantConfigResponse, - GetTenantOkResult, - EmailPasswordConfig, - PasswordlessConfig, - ThirdPartyConfig, ListAllTenantsOkResult, CreateOrUpdateThirdPartyConfigOkResult, DeleteThirdPartyConfigOkResult, - ListAllTenantsItem, ) if TYPE_CHECKING: @@ -48,7 +42,7 @@ from .constants import DEFAULT_TENANT_ID -def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: +def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfig: from supertokens_python.recipe.thirdparty.provider import ( UserInfoMap, UserFields, @@ -109,13 +103,12 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: ) ) - return TenantConfigResponse( - emailpassword=EmailPasswordConfig(tenant["emailPassword"]["enabled"]), - passwordless=PasswordlessConfig(tenant["passwordless"]["enabled"]), - third_party=ThirdPartyConfig( - tenant["thirdParty"]["enabled"], - providers, - ), + return TenantConfig( + tenant_id=tenant["tenantId"], + email_password_enabled=tenant["emailPassword"]["enabled"], + passwordless_enabled=tenant["passwordless"]["enabled"], + third_party_providers=providers, + third_party_enabled=tenant["thirdParty"]["enabled"], core_config=tenant["coreConfig"], ) @@ -163,7 +156,7 @@ async def delete_tenant( async def get_tenant( self, tenant_id: Optional[str], user_context: Dict[str, Any] - ) -> Optional[GetTenantOkResult]: + ) -> Optional[TenantConfig]: res = await self.querier.send_get_request( NormalisedURLPath( f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant" @@ -177,12 +170,7 @@ async def get_tenant( tenant_config = parse_tenant_config(res) - return GetTenantOkResult( - emailpassword=tenant_config.emailpassword, - passwordless=tenant_config.passwordless, - third_party=tenant_config.third_party, - core_config=tenant_config.core_config, - ) + return tenant_config async def list_all_tenants( self, user_context: Dict[str, Any] @@ -193,18 +181,11 @@ async def list_all_tenants( user_context=user_context, ) - tenant_items: List[ListAllTenantsItem] = [] + tenant_items: List[TenantConfig] = [] for tenant in response["tenants"]: config = parse_tenant_config(tenant) - item = ListAllTenantsItem( - tenant["tenantId"], - config.emailpassword, - config.passwordless, - config.third_party, - config.core_config, - ) - tenant_items.append(item) + tenant_items.append(config) return ListAllTenantsOkResult( tenants=tenant_items, diff --git a/supertokens_python/recipe/thirdparty/recipe_implementation.py b/supertokens_python/recipe/thirdparty/recipe_implementation.py index 328fc0129..0647212ea 100644 --- a/supertokens_python/recipe/thirdparty/recipe_implementation.py +++ b/supertokens_python/recipe/thirdparty/recipe_implementation.py @@ -203,7 +203,7 @@ async def get_provider( raise Exception("Tenant not found") merged_providers = merge_providers_from_core_and_static( - provider_configs_from_core=tenant_config.third_party.providers, + provider_configs_from_core=tenant_config.third_party_providers, provider_inputs_from_static=self.providers, include_all_providers=tenant_id == DEFAULT_TENANT_ID, ) diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 2fb10fe00..051dc01b8 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -76,32 +76,32 @@ async def test_tenant_crud(): t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config.emailpassword.enabled is True - assert t1_config.passwordless.enabled is False - assert t1_config.third_party.enabled is False + assert t1_config.email_password_enabled is True + assert t1_config.passwordless_enabled is False + assert t1_config.third_party_enabled is False assert t1_config.core_config == {} t2_config = await get_tenant("t2") assert t2_config is not None - assert t2_config.emailpassword.enabled is False - assert t2_config.passwordless.enabled is True - assert t2_config.third_party.enabled is False + assert t2_config.email_password_enabled is False + assert t2_config.passwordless_enabled is True + assert t2_config.third_party_enabled is False assert t2_config.core_config == {} t3_config = await get_tenant("t3") assert t3_config is not None - assert t3_config.emailpassword.enabled is False - assert t3_config.passwordless.enabled is False - assert t3_config.third_party.enabled is True + assert t3_config.email_password_enabled is False + assert t3_config.passwordless_enabled is False + assert t3_config.third_party_enabled is True assert t3_config.core_config == {} # update tenant1 to add passwordless: await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True)) t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config.emailpassword.enabled is True - assert t1_config.passwordless.enabled is True - assert t1_config.third_party.enabled is False + assert t1_config.email_password_enabled is True + assert t1_config.passwordless_enabled is True + assert t1_config.third_party_enabled is False assert t1_config.core_config == {} # update tenant1 to add thirdparty: @@ -109,9 +109,9 @@ async def test_tenant_crud(): t1_config = await get_tenant("t1") assert t1_config is not None assert t1_config is not None - assert t1_config.emailpassword.enabled is True - assert t1_config.passwordless.enabled is True - assert t1_config.third_party.enabled is True + assert t1_config.email_password_enabled is True + assert t1_config.passwordless_enabled is True + assert t1_config.third_party_enabled is True assert t1_config.core_config == {} # delete tenant2: @@ -139,8 +139,8 @@ async def test_tenant_thirdparty_config(): tenant_config = await get_tenant("t1") assert tenant_config is not None - assert len(tenant_config.third_party.providers) == 1 - provider = tenant_config.third_party.providers[0] + assert len(tenant_config.third_party_providers) == 1 + provider = tenant_config.third_party_providers[0] assert provider.third_party_id == "google" assert provider.clients is not None assert len(provider.clients) == 1 @@ -197,8 +197,8 @@ async def generate_fake_email(_: str, __: str, ___: Dict[str, Any]): tenant_config = await get_tenant("t1") assert tenant_config is not None - assert len(tenant_config.third_party.providers) == 1 - provider = tenant_config.third_party.providers[0] + assert len(tenant_config.third_party_providers) == 1 + provider = tenant_config.third_party_providers[0] assert provider.third_party_id == "google" assert provider.name == "Custom name" assert provider.clients is not None @@ -244,7 +244,7 @@ async def generate_fake_email(_: str, __: str, ___: Dict[str, Any]): tenant_config = await get_tenant("t1") assert tenant_config is not None - assert len(tenant_config.third_party.providers) == 0 + assert len(tenant_config.third_party_providers) == 0 async def test_user_association_and_disassociation_with_tenants(): diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index 6af638aae..5f9625c38 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -62,9 +62,9 @@ def get_tenant(): # type: ignore { "status": "OK", "tenant": { - "emailPassword": response.emailpassword.to_json(), - "thirdParty": response.third_party.to_json(), - "passwordless": response.passwordless.to_json(), + "emailPassword": {"enabled": response.email_password_enabled}, + "thirdParty": {"enabled": response.third_party_enabled}, + "passwordless": {"enabled": response.passwordless_enabled}, "coreConfig": response.core_config, }, } @@ -119,32 +119,34 @@ def create_or_update_third_party_config(): # type: ignore user_info_endpoint_headers=config.get("userInfoEndpointHeaders"), jwks_uri=config.get("jwksURI"), oidc_discovery_endpoint=config.get("oidcDiscoveryEndpoint"), - user_info_map=UserInfoMap( - from_id_token_payload=UserFields( - user_id=config.get("userInfoMap", {}) - .get("fromIdTokenPayload", {}) - .get("userId"), - email=config.get("userInfoMap", {}) - .get("fromIdTokenPayload", {}) - .get("email"), - email_verified=config.get("userInfoMap", {}) - .get("fromIdTokenPayload", {}) - .get("emailVerified"), - ), - from_user_info_api=UserFields( - user_id=config.get("userInfoMap", {}) - .get("fromUserInfoAPI", {}) - .get("userId"), - email=config.get("userInfoMap", {}) - .get("fromUserInfoAPI", {}) - .get("email"), - email_verified=config.get("userInfoMap", {}) - .get("fromUserInfoAPI", {}) - .get("emailVerified"), - ), - ) - if "userInfoMap" in config - else None, + user_info_map=( + UserInfoMap( + from_id_token_payload=UserFields( + user_id=config.get("userInfoMap", {}) + .get("fromIdTokenPayload", {}) + .get("userId"), + email=config.get("userInfoMap", {}) + .get("fromIdTokenPayload", {}) + .get("email"), + email_verified=config.get("userInfoMap", {}) + .get("fromIdTokenPayload", {}) + .get("emailVerified"), + ), + from_user_info_api=UserFields( + user_id=config.get("userInfoMap", {}) + .get("fromUserInfoAPI", {}) + .get("userId"), + email=config.get("userInfoMap", {}) + .get("fromUserInfoAPI", {}) + .get("email"), + email_verified=config.get("userInfoMap", {}) + .get("fromUserInfoAPI", {}) + .get("emailVerified"), + ), + ) + if "userInfoMap" in config + else None + ), require_email=config.get("requireEmail", True), ) ) From ad9ae3579efac18a579f5d8bf70f9b9a27851303 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 25 Aug 2024 20:31:33 +0530 Subject: [PATCH 020/126] starts working on totp --- CONTRIBUTING.md | 7 +- supertokens_python/recipe/totp/interfaces.py | 188 +++++++++++++++++++ 2 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 supertokens_python/recipe/totp/interfaces.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 97de23a18..b9fa248d4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,4 +81,9 @@ Additional resources you might find useful: This will generate the API docs in a folder called docs ``` make build-docs -``` \ No newline at end of file +``` + +## AI prompt rules for generating code (with Claude sonnet-3.5): +- Avoid using TypeDict +- Avoid using generic Dict as much as possible, except when defining the types for user_context. Instead, create a custom class to indicate that type. If a function has multiple status Literal output, then there should be one class per status. +- Variable and function names should be in snake_case. Class names in PascalCase. \ No newline at end of file diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py new file mode 100644 index 000000000..90a60cc11 --- /dev/null +++ b/supertokens_python/recipe/totp/interfaces.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Any, Union +from typing_extensions import Literal +from abc import ABC, abstractmethod + + +class OkResult: + def __init__(self): + self.status: Literal["OK"] = "OK" + + +class UserIdentifierInfoOkResult(OkResult): + def __init__(self, info: str): + super().__init__() + self.info: str = info + + +class UnknownUserIdError: + def __init__(self): + self.status: Literal["UNKNOWN_USER_ID_ERROR"] = "UNKNOWN_USER_ID_ERROR" + + +class UserIdentifierInfoDoesNotExistError: + def __init__(self): + self.status: Literal[ + "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" + ] = "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" + + +class CreateDeviceOkResult(OkResult): + def __init__(self, device_name: str, secret: str, qr_code_string: str): + super().__init__() + self.device_name: str = device_name + self.secret: str = secret + self.qr_code_string: str = qr_code_string + + +class DeviceAlreadyExistsError: + def __init__(self): + self.status: Literal[ + "DEVICE_ALREADY_EXISTS_ERROR" + ] = "DEVICE_ALREADY_EXISTS_ERROR" + + +class UpdateDeviceOkResult(OkResult): + pass + + +class UnknownDeviceError: + def __init__(self): + self.status: Literal["UNKNOWN_DEVICE_ERROR"] = "UNKNOWN_DEVICE_ERROR" + + +class Device: + def __init__(self, name: str, period: int, skew: int, verified: bool): + self.name: str = name + self.period: int = period + self.skew: int = skew + self.verified: bool = verified + + +class ListDevicesOkResult(OkResult): + def __init__(self, devices: List[Device]): + super().__init__() + self.devices: List[Device] = devices + + +class RemoveDeviceOkResult(OkResult): + def __init__(self, did_device_exist: bool): + super().__init__() + self.did_device_exist: bool = did_device_exist + + +class VerifyDeviceOkResult(OkResult): + def __init__( + self, + was_already_verified: bool, + ): + super().__init__() + self.was_already_verified: bool = was_already_verified + + +class InvalidTOTPError: + def __init__( + self, current_number_of_failed_attempts: int, max_number_of_failed_attempts: int + ): + self.status: Literal["INVALID_TOTP_ERROR"] = "INVALID_TOTP_ERROR" + self.current_number_of_failed_attempts: int = current_number_of_failed_attempts + self.max_number_of_failed_attempts: int = max_number_of_failed_attempts + + +class LimitReachedError: + def __init__(self, retry_after_ms: int): + self.status: Literal["LIMIT_REACHED_ERROR"] = "LIMIT_REACHED_ERROR" + self.retry_after_ms: int = retry_after_ms + + +class VerifyTOTPOkResult(OkResult): + def __init__( + self, + ): + super().__init__() + + +class RecipeInterface(ABC): + @abstractmethod + async def get_user_identifier_info_for_user_id( + self, user_id: str, user_context: Dict[str, Any] + ) -> Union[ + UserIdentifierInfoOkResult, + UnknownUserIdError, + UserIdentifierInfoDoesNotExistError, + ]: + pass + + @abstractmethod + async def create_device( + self, + user_id: str, + user_identifier_info: Union[str, None] = None, + device_name: Union[str, None] = None, + skew: Union[int, None] = None, + period: Union[int, None] = None, + user_context: Union[Dict[str, Any], None] = None, + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError,]: + pass + + @abstractmethod + async def update_device( + self, + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Dict[str, Any], + ) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError,]: + pass + + @abstractmethod + async def list_devices( + self, user_id: str, user_context: Dict[str, Any] + ) -> ListDevicesOkResult: + pass + + @abstractmethod + async def remove_device( + self, user_id: str, device_name: str, user_context: Dict[str, Any] + ) -> RemoveDeviceOkResult: + pass + + @abstractmethod + async def verify_device( + self, + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + ]: + pass + + @abstractmethod + async def verify_totp( + self, tenant_id: str, user_id: str, totp: str, user_context: Dict[str, Any] + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + ]: + pass From 1c15536f720eac9bf070f85b17c89f1ae4ab0681 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 25 Aug 2024 20:53:11 +0530 Subject: [PATCH 021/126] more code --- .cursorrules | 10 ++ CONTRIBUTING.md | 7 +- supertokens_python/recipe/totp/interfaces.py | 144 ++++++++++++++++++- 3 files changed, 148 insertions(+), 13 deletions(-) create mode 100644 .cursorrules diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 000000000..89a708854 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,10 @@ +You are an expert Python and Typescript developer. Your job is to convert node code (in typescript) into Python code. The python code then goes into this SDK. The python code style should keep in mind: +- Avoid using TypeDict +- Avoid using generic Dict as much as possible, except when defining the types for `user_context`. +- If a function has multiple `status` strings as outputs, then define one unique class per unique `status` string. The class name should be such that it indicates the status it is associated with. +- Variable and function names should be in snake_case. Class names in PascalCase. +- Whenever importing `Literal`, import it from `typing_extensions`, and not `types`. +- Do not use `|` for OR type, instead use `Union` +- When defining API interface functions, make sure the output classes inherit from `APIResponse` class, and that they have a `to_json` function defined whose output matches the structure of the provided Typescript code output objects. + +The semantic of the python code should be the same as what's of the provided Typescript code. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b9fa248d4..97de23a18 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,9 +81,4 @@ Additional resources you might find useful: This will generate the API docs in a folder called docs ``` make build-docs -``` - -## AI prompt rules for generating code (with Claude sonnet-3.5): -- Avoid using TypeDict -- Avoid using generic Dict as much as possible, except when defining the types for user_context. Instead, create a custom class to indicate that type. If a function has multiple status Literal output, then there should be one class per status. -- Variable and function names should be in snake_case. Class names in PascalCase. \ No newline at end of file +``` \ No newline at end of file diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index 90a60cc11..f1491cd88 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -16,8 +16,11 @@ from typing_extensions import Literal from abc import ABC, abstractmethod +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import APIResponse, GeneralErrorResponse -class OkResult: + +class OkResult(APIResponse): def __init__(self): self.status: Literal["OK"] = "OK" @@ -28,10 +31,13 @@ def __init__(self, info: str): self.info: str = info -class UnknownUserIdError: +class UnknownUserIdError(APIResponse): def __init__(self): self.status: Literal["UNKNOWN_USER_ID_ERROR"] = "UNKNOWN_USER_ID_ERROR" + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + class UserIdentifierInfoDoesNotExistError: def __init__(self): @@ -47,42 +53,76 @@ def __init__(self, device_name: str, secret: str, qr_code_string: str): self.secret: str = secret self.qr_code_string: str = qr_code_string + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "deviceName": self.device_name, + "secret": self.secret, + "qrCodeString": self.qr_code_string, + } -class DeviceAlreadyExistsError: + +class DeviceAlreadyExistsError(APIResponse): def __init__(self): self.status: Literal[ "DEVICE_ALREADY_EXISTS_ERROR" ] = "DEVICE_ALREADY_EXISTS_ERROR" + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + class UpdateDeviceOkResult(OkResult): pass -class UnknownDeviceError: +class UnknownDeviceError(APIResponse): def __init__(self): self.status: Literal["UNKNOWN_DEVICE_ERROR"] = "UNKNOWN_DEVICE_ERROR" + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} -class Device: + +class Device(APIResponse): def __init__(self, name: str, period: int, skew: int, verified: bool): self.name: str = name self.period: int = period self.skew: int = skew self.verified: bool = verified + def to_json(self) -> Dict[str, Any]: + return { + "name": self.name, + "period": self.period, + "skew": self.skew, + "verified": self.verified, + } + class ListDevicesOkResult(OkResult): def __init__(self, devices: List[Device]): super().__init__() self.devices: List[Device] = devices + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "devices": [device.to_json() for device in self.devices], + } + class RemoveDeviceOkResult(OkResult): def __init__(self, did_device_exist: bool): super().__init__() self.did_device_exist: bool = did_device_exist + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "didDeviceExist": self.did_device_exist, + } + class VerifyDeviceOkResult(OkResult): def __init__( @@ -92,8 +132,14 @@ def __init__( super().__init__() self.was_already_verified: bool = was_already_verified + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "wasAlreadyVerified": self.was_already_verified, + } -class InvalidTOTPError: + +class InvalidTOTPError(APIResponse): def __init__( self, current_number_of_failed_attempts: int, max_number_of_failed_attempts: int ): @@ -101,12 +147,25 @@ def __init__( self.current_number_of_failed_attempts: int = current_number_of_failed_attempts self.max_number_of_failed_attempts: int = max_number_of_failed_attempts + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "currentNumberOfFailedAttempts": self.current_number_of_failed_attempts, + "maxNumberOfFailedAttempts": self.max_number_of_failed_attempts, + } + -class LimitReachedError: +class LimitReachedError(APIResponse): def __init__(self, retry_after_ms: int): self.status: Literal["LIMIT_REACHED_ERROR"] = "LIMIT_REACHED_ERROR" self.retry_after_ms: int = retry_after_ms + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "retryAfterMs": self.retry_after_ms, + } + class VerifyTOTPOkResult(OkResult): def __init__( @@ -114,6 +173,9 @@ def __init__( ): super().__init__() + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + class RecipeInterface(ABC): @abstractmethod @@ -186,3 +248,71 @@ async def verify_totp( LimitReachedError, ]: pass + + +class APIOptions: + pass + + +class APIInterface(ABC): + @abstractmethod + async def create_device_post( + self, + device_name: Union[str, None], + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, GeneralErrorResponse]: + pass + + @abstractmethod + async def list_devices_get( + self, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ListDevicesOkResult, GeneralErrorResponse]: + pass + + @abstractmethod + async def remove_device_post( + self, + device_name: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[RemoveDeviceOkResult, GeneralErrorResponse]: + pass + + @abstractmethod + async def verify_device_post( + self, + device_name: str, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + pass + + @abstractmethod + async def verify_totp_post( + self, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + pass From eb1a5b4cb17d66f936fd1bc9e822be4b87685548 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 26 Aug 2024 15:26:09 +0530 Subject: [PATCH 022/126] adds totp recipe --- .../recipe/multifactorauth/recipe.py | 4 + .../recipe/totp/api/create_device.py | 58 ++++ .../recipe/totp/api/implementation.py | 173 ++++++++++++ .../recipe/totp/api/list_devices.py | 48 ++++ .../recipe/totp/api/remove_device.py | 60 ++++ .../recipe/totp/api/verify_device.py | 66 +++++ .../recipe/totp/api/verify_totp.py | 61 ++++ .../recipe/totp/asyncio/__init__.py | 126 +++++++++ supertokens_python/recipe/totp/constants.py | 19 ++ supertokens_python/recipe/totp/interfaces.py | 201 +++---------- supertokens_python/recipe/totp/recipe.py | 238 ++++++++++++++++ .../recipe/totp/recipe_implementation.py | 266 ++++++++++++++++++ .../recipe/totp/syncio/__init__.py | 129 +++++++++ supertokens_python/recipe/totp/types.py | 221 +++++++++++++++ supertokens_python/recipe/totp/utils.py | 44 +++ 15 files changed, 1548 insertions(+), 166 deletions(-) create mode 100644 supertokens_python/recipe/totp/api/create_device.py create mode 100644 supertokens_python/recipe/totp/api/implementation.py create mode 100644 supertokens_python/recipe/totp/api/list_devices.py create mode 100644 supertokens_python/recipe/totp/api/remove_device.py create mode 100644 supertokens_python/recipe/totp/api/verify_device.py create mode 100644 supertokens_python/recipe/totp/api/verify_totp.py create mode 100644 supertokens_python/recipe/totp/asyncio/__init__.py create mode 100644 supertokens_python/recipe/totp/constants.py create mode 100644 supertokens_python/recipe/totp/recipe.py create mode 100644 supertokens_python/recipe/totp/recipe_implementation.py create mode 100644 supertokens_python/recipe/totp/syncio/__init__.py create mode 100644 supertokens_python/recipe/totp/types.py create mode 100644 supertokens_python/recipe/totp/utils.py diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index c8882b07c..30bdcaeb1 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -188,6 +188,10 @@ def get_instance_or_throw_error() -> MultiFactorAuthRecipe: "MultiFactorAuth recipe initialisation not done. Did you forget to call the SuperTokens.init function?" ) + @staticmethod + def get_instance() -> Optional[MultiFactorAuthRecipe]: + return MultiFactorAuthRecipe.__instance + @staticmethod def reset(): if ("SUPERTOKENS_ENV" not in environ) or ( diff --git a/supertokens_python/recipe/totp/api/create_device.py b/supertokens_python/recipe/totp/api/create_device.py new file mode 100644 index 000000000..7a7ba2ef2 --- /dev/null +++ b/supertokens_python/recipe/totp/api/create_device.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union + +from supertokens_python.framework import BaseResponse + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + + +async def handle_create_device_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_create_device_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + + device_name = body.get("deviceName") + + if device_name is not None and not isinstance(device_name, str): + raise_bad_input_exception("deviceName must be a string") + + response = await api_implementation.create_device_post( + device_name, api_options, session, user_context + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/implementation.py b/supertokens_python/recipe/totp/api/implementation.py new file mode 100644 index 000000000..3713a7c4d --- /dev/null +++ b/supertokens_python/recipe/totp/api/implementation.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.multifactorauth.asyncio import ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from ..interfaces import APIInterface, APIOptions +from ..types import ( + CreateDeviceOkResult, + DeviceAlreadyExistsError, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + VerifyTOTPOkResult, + UnknownUserIdError, +) +from supertokens_python.types import GeneralErrorResponse + + +class APIImplementation(APIInterface): + async def create_device_post( + self, + device_name: Union[str, None], + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, GeneralErrorResponse]: + user_id = session.get_user_id() + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is None: + raise Exception("should never come here") + + await assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session, "totp", user_context + ) + + create_device_res = await options.recipe_implementation.create_device( + user_id=user_id, + user_identifier_info=None, + device_name=device_name, + skew=None, + period=None, + user_context=user_context, + ) + + if isinstance(create_device_res, UnknownUserIdError): + raise UnauthorisedError("Session user not found") + return create_device_res + + async def list_devices_get( + self, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ListDevicesOkResult, GeneralErrorResponse]: + user_id = session.get_user_id() + + return await options.recipe_implementation.list_devices( + user_id=user_id, + user_context=user_context, + ) + + async def remove_device_post( + self, + device_name: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[RemoveDeviceOkResult, GeneralErrorResponse]: + user_id = session.get_user_id() + + return await options.recipe_implementation.remove_device( + user_id=user_id, + device_name=device_name, + user_context=user_context, + ) + + async def verify_device_post( + self, + device_name: str, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + user_id = session.get_user_id() + tenant_id = session.get_tenant_id() + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is None: + raise Exception("should never come here") + + await assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session, "totp", user_context + ) + + res = await options.recipe_implementation.verify_device( + tenant_id=tenant_id, + user_id=user_id, + device_name=device_name, + totp=totp, + user_context=user_context, + ) + + if isinstance(res, VerifyDeviceOkResult): + await mfa_instance.recipe_implementation.mark_factor_as_complete_in_session( + session=session, + factor_id="totp", + user_context=user_context, + ) + + return res + + async def verify_totp_post( + self, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + user_id = session.get_user_id() + tenant_id = session.get_tenant_id() + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is None: + raise Exception("should never come here") + + res = await options.recipe_implementation.verify_totp( + tenant_id=tenant_id, + user_id=user_id, + totp=totp, + user_context=user_context, + ) + + if isinstance(res, VerifyTOTPOkResult): + await mfa_instance.recipe_implementation.mark_factor_as_complete_in_session( + session=session, + factor_id="totp", + user_context=user_context, + ) + + return res diff --git a/supertokens_python/recipe/totp/api/list_devices.py b/supertokens_python/recipe/totp/api/list_devices.py new file mode 100644 index 000000000..3030fa3a3 --- /dev/null +++ b/supertokens_python/recipe/totp/api/list_devices.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_list_devices_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_list_devices_get: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + response = await api_implementation.list_devices_get( + api_options, session, user_context + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/remove_device.py b/supertokens_python/recipe/totp/api/remove_device.py new file mode 100644 index 000000000..128424b94 --- /dev/null +++ b/supertokens_python/recipe/totp/api/remove_device.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_remove_device_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_remove_device_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + device_name = body.get("deviceName") + + if device_name is None or not isinstance(device_name, str) or len(device_name) == 0: + raise Exception("deviceName is required and must be a non-empty string") + + response = await api_implementation.remove_device_post( + device_name=device_name, + options=api_options, + session=session, + user_context=user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/verify_device.py b/supertokens_python/recipe/totp/api/verify_device.py new file mode 100644 index 000000000..b57e091b1 --- /dev/null +++ b/supertokens_python/recipe/totp/api/verify_device.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_verify_device_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_verify_device_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + + device_name = body.get("deviceName") + totp = body.get("totp") + + if device_name is None or not isinstance(device_name, str): + raise Exception("deviceName is required and must be a string") + + if totp is None or not isinstance(totp, str): + raise Exception("totp is required and must be a string") + + response = await api_implementation.verify_device_post( + device_name=device_name, + totp=totp, + options=api_options, + session=session, + user_context=user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/verify_totp.py b/supertokens_python/recipe/totp/api/verify_totp.py new file mode 100644 index 000000000..1dbe91218 --- /dev/null +++ b/supertokens_python/recipe/totp/api/verify_totp.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_verify_totp_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_verify_totp_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + + totp = body.get("totp") + + if totp is None or not isinstance(totp, str): + raise Exception("totp is required and must be a string") + + response = await api_implementation.verify_totp_post( + totp=totp, + options=api_options, + session=session, + user_context=user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/asyncio/__init__.py b/supertokens_python/recipe/totp/asyncio/__init__.py new file mode 100644 index 000000000..5d83bbabb --- /dev/null +++ b/supertokens_python/recipe/totp/asyncio/__init__.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Union, Optional +from ..recipe import TOTPRecipe +from supertokens_python.recipe.totp.types import ( + CreateDeviceOkResult, + DeviceAlreadyExistsError, + UnknownUserIdError, + UpdateDeviceOkResult, + UnknownDeviceError, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + InvalidTOTPError, + LimitReachedError, + VerifyTOTPOkResult, +) + + +async def create_device( + user_id: str, + user_identifier_info: Optional[str] = None, + device_name: Optional[str] = None, + skew: Optional[int] = None, + period: Optional[int] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.create_device( + user_id, + user_identifier_info, + device_name, + skew, + period, + user_context, + ) + + +async def update_device( + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.update_device( + user_id, + existing_device_name, + new_device_name, + user_context, + ) + + +async def list_devices( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> ListDevicesOkResult: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.list_devices( + user_id, + user_context, + ) + + +async def remove_device( + user_id: str, + device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> RemoveDeviceOkResult: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.remove_device( + user_id, + device_name, + user_context, + ) + + +async def verify_device( + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + VerifyDeviceOkResult, UnknownDeviceError, InvalidTOTPError, LimitReachedError +]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.verify_device( + tenant_id, + user_id, + device_name, + totp, + user_context, + ) + + +async def verify_totp( + tenant_id: str, + user_id: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[VerifyTOTPOkResult, UnknownUserIdError, InvalidTOTPError, LimitReachedError]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.verify_totp( + tenant_id, + user_id, + totp, + user_context, + ) diff --git a/supertokens_python/recipe/totp/constants.py b/supertokens_python/recipe/totp/constants.py new file mode 100644 index 000000000..bf421c46c --- /dev/null +++ b/supertokens_python/recipe/totp/constants.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +CREATE_TOTP_DEVICE = "/totp/device" +LIST_TOTP_DEVICES = "/totp/device/list" +REMOVE_TOTP_DEVICE = "/totp/device/remove" +VERIFY_TOTP_DEVICE = "/totp/device/verify" +VERIFY_TOTP = "/totp/verify" diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index f1491cd88..6b35b9b47 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -12,169 +12,15 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import List, Dict, Any, Union -from typing_extensions import Literal +from typing import Dict, Any, Union from abc import ABC, abstractmethod +from supertokens_python import AppInfo +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer -from supertokens_python.types import APIResponse, GeneralErrorResponse - - -class OkResult(APIResponse): - def __init__(self): - self.status: Literal["OK"] = "OK" - - -class UserIdentifierInfoOkResult(OkResult): - def __init__(self, info: str): - super().__init__() - self.info: str = info - - -class UnknownUserIdError(APIResponse): - def __init__(self): - self.status: Literal["UNKNOWN_USER_ID_ERROR"] = "UNKNOWN_USER_ID_ERROR" - - def to_json(self) -> Dict[str, Any]: - return {"status": self.status} - - -class UserIdentifierInfoDoesNotExistError: - def __init__(self): - self.status: Literal[ - "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" - ] = "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" - - -class CreateDeviceOkResult(OkResult): - def __init__(self, device_name: str, secret: str, qr_code_string: str): - super().__init__() - self.device_name: str = device_name - self.secret: str = secret - self.qr_code_string: str = qr_code_string - - def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "deviceName": self.device_name, - "secret": self.secret, - "qrCodeString": self.qr_code_string, - } - - -class DeviceAlreadyExistsError(APIResponse): - def __init__(self): - self.status: Literal[ - "DEVICE_ALREADY_EXISTS_ERROR" - ] = "DEVICE_ALREADY_EXISTS_ERROR" - - def to_json(self) -> Dict[str, Any]: - return {"status": self.status} - - -class UpdateDeviceOkResult(OkResult): - pass - - -class UnknownDeviceError(APIResponse): - def __init__(self): - self.status: Literal["UNKNOWN_DEVICE_ERROR"] = "UNKNOWN_DEVICE_ERROR" - - def to_json(self) -> Dict[str, Any]: - return {"status": self.status} - - -class Device(APIResponse): - def __init__(self, name: str, period: int, skew: int, verified: bool): - self.name: str = name - self.period: int = period - self.skew: int = skew - self.verified: bool = verified - - def to_json(self) -> Dict[str, Any]: - return { - "name": self.name, - "period": self.period, - "skew": self.skew, - "verified": self.verified, - } - - -class ListDevicesOkResult(OkResult): - def __init__(self, devices: List[Device]): - super().__init__() - self.devices: List[Device] = devices - - def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "devices": [device.to_json() for device in self.devices], - } - - -class RemoveDeviceOkResult(OkResult): - def __init__(self, did_device_exist: bool): - super().__init__() - self.did_device_exist: bool = did_device_exist - - def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "didDeviceExist": self.did_device_exist, - } - - -class VerifyDeviceOkResult(OkResult): - def __init__( - self, - was_already_verified: bool, - ): - super().__init__() - self.was_already_verified: bool = was_already_verified - - def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "wasAlreadyVerified": self.was_already_verified, - } - - -class InvalidTOTPError(APIResponse): - def __init__( - self, current_number_of_failed_attempts: int, max_number_of_failed_attempts: int - ): - self.status: Literal["INVALID_TOTP_ERROR"] = "INVALID_TOTP_ERROR" - self.current_number_of_failed_attempts: int = current_number_of_failed_attempts - self.max_number_of_failed_attempts: int = max_number_of_failed_attempts - - def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "currentNumberOfFailedAttempts": self.current_number_of_failed_attempts, - "maxNumberOfFailedAttempts": self.max_number_of_failed_attempts, - } - - -class LimitReachedError(APIResponse): - def __init__(self, retry_after_ms: int): - self.status: Literal["LIMIT_REACHED_ERROR"] = "LIMIT_REACHED_ERROR" - self.retry_after_ms: int = retry_after_ms - - def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "retryAfterMs": self.retry_after_ms, - } - - -class VerifyTOTPOkResult(OkResult): - def __init__( - self, - ): - super().__init__() - - def to_json(self) -> Dict[str, Any]: - return {"status": self.status} +from supertokens_python.recipe.totp.recipe import TOTPRecipe +from supertokens_python.types import GeneralErrorResponse +from .types import * class RecipeInterface(ABC): @@ -192,11 +38,11 @@ async def get_user_identifier_info_for_user_id( async def create_device( self, user_id: str, - user_identifier_info: Union[str, None] = None, - device_name: Union[str, None] = None, - skew: Union[int, None] = None, - period: Union[int, None] = None, - user_context: Union[Dict[str, Any], None] = None, + user_identifier_info: Optional[str], + device_name: Optional[str], + skew: Optional[int], + period: Optional[int], + user_context: Dict[str, Any], ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError,]: pass @@ -251,10 +97,33 @@ async def verify_totp( class APIOptions: - pass + def __init__( + self, + request: BaseRequest, + response: BaseResponse, + recipe_id: str, + config: TOTPNormalisedConfig, + recipe_implementation: RecipeInterface, + app_info: AppInfo, + recipe_instance: TOTPRecipe, + ): + self.request: BaseRequest = request + self.response: BaseResponse = response + self.recipe_id: str = recipe_id + self.config = config + self.recipe_implementation: RecipeInterface = recipe_implementation + self.app_info = app_info + self.recipe_instance = recipe_instance class APIInterface(ABC): + def __init__(self): + self.disable_create_device_post = False + self.disable_list_devices_get = False + self.disable_remove_device_post = False + self.disable_verify_device_post = False + self.disable_verify_totp_post = False + @abstractmethod async def create_device_post( self, diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py new file mode 100644 index 000000000..d9438346c --- /dev/null +++ b/supertokens_python/recipe/totp/recipe.py @@ -0,0 +1,238 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from os import environ +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.multifactorauth.types import ( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetFactorsSetupForUserFromOtherRecipesFunc, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.querier import Querier +from supertokens_python.types import AccountLinkingUser + +from .recipe_implementation import RecipeImplementation +from .api.implementation import APIImplementation +from .interfaces import APIInterface, RecipeInterface +from .utils import validate_and_normalise_user_input +from ...post_init_callbacks import PostSTInitCallbacks +from ..multifactorauth.recipe import MultiFactorAuthRecipe + +if TYPE_CHECKING: + from supertokens_python.framework.request import BaseRequest + from supertokens_python.framework.response import BaseResponse + from supertokens_python.supertokens import AppInfo + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception + +from api.list_devices import handle_list_devices_api +from .api.create_device import handle_create_device_api +from .api.remove_device import handle_remove_device_api +from .api.verify_device import handle_verify_device_api +from .api.verify_totp import handle_verify_totp_api +from .constants import ( + CREATE_TOTP_DEVICE, + VERIFY_TOTP_DEVICE, + VERIFY_TOTP, + LIST_TOTP_DEVICES, + REMOVE_TOTP_DEVICE, +) +from .interfaces import APIOptions +from .types import TOTPConfig + + +class TOTPRecipe(RecipeModule): + recipe_id = "totp" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + config: Union[TOTPConfig, None] = None, + ): + super().__init__(recipe_id, app_info) + self.config = validate_and_normalise_user_input(app_info, config) + + recipe_implementation = RecipeImplementation( + Querier.get_instance(recipe_id), self.config + ) + self.recipe_implementation: RecipeInterface = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) + + api_implementation = APIImplementation() + self.api_implementation: APIInterface = ( + api_implementation + if self.config.override.apis is None + else self.config.override.apis(api_implementation) + ) + + def callback(): + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is not None: + + async def f1(_: TenantConfig): + return ["totp"] + + async def f2( + user: AccountLinkingUser, user_context: Dict[str, Any] + ) -> List[str]: + device_res = await TOTPRecipe.get_instance_or_throw().recipe_implementation.list_devices( + user_id=user.id, user_context=user_context + ) + for device in device_res.devices: + if device.verified: + return ["totp"] + return [] + + mfa_instance.add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc(f1) + ) + mfa_instance.add_func_to_get_factors_setup_for_user_from_other_recipes( + GetFactorsSetupForUserFromOtherRecipesFunc(f2) + ) + + PostSTInitCallbacks.add_post_init_callback(callback) + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + def get_apis_handled(self) -> List[APIHandled]: + return [ + APIHandled( + NormalisedURLPath(CREATE_TOTP_DEVICE), + "post", + CREATE_TOTP_DEVICE, + self.api_implementation.disable_create_device_post, + ), + APIHandled( + NormalisedURLPath(LIST_TOTP_DEVICES), + "get", + LIST_TOTP_DEVICES, + self.api_implementation.disable_list_devices_get, + ), + APIHandled( + NormalisedURLPath(REMOVE_TOTP_DEVICE), + "post", + REMOVE_TOTP_DEVICE, + self.api_implementation.disable_remove_device_post, + ), + APIHandled( + NormalisedURLPath(VERIFY_TOTP_DEVICE), + "post", + VERIFY_TOTP_DEVICE, + self.api_implementation.disable_verify_device_post, + ), + APIHandled( + NormalisedURLPath(VERIFY_TOTP), + "post", + VERIFY_TOTP, + self.api_implementation.disable_verify_totp_post, + ), + ] + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ): + api_options = APIOptions( + request, + response, + self.recipe_id, + self.config, + self.recipe_implementation, + self.get_app_info(), + self, + ) + if request_id == CREATE_TOTP_DEVICE: + return await handle_create_device_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == LIST_TOTP_DEVICES: + return await handle_list_devices_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == REMOVE_TOTP_DEVICE: + return await handle_remove_device_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == VERIFY_TOTP_DEVICE: + return await handle_verify_device_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == VERIFY_TOTP: + return await handle_verify_totp_api( + tenant_id, self.api_implementation, api_options, user_context + ) + + return None + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + config: Union[TOTPConfig, None] = None, + ): + def func(app_info: AppInfo): + if TOTPRecipe.__instance is None: + TOTPRecipe.__instance = TOTPRecipe( + TOTPRecipe.recipe_id, + app_info, + config, + ) + return TOTPRecipe.__instance + raise Exception( + "TOTP recipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def get_instance_or_throw() -> TOTPRecipe: + if TOTPRecipe.__instance is not None: + return TOTPRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + TOTPRecipe.__instance = None diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py new file mode 100644 index 000000000..5b4ec544f --- /dev/null +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -0,0 +1,266 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from urllib.parse import quote + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.totp.interfaces import ( + RecipeInterface, + CreateDeviceOkResult, + UnknownUserIdError, + UpdateDeviceOkResult, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + VerifyTOTPOkResult, + UserIdentifierInfoOkResult, + UserIdentifierInfoDoesNotExistError, +) +from .types import ( + Device, + DeviceAlreadyExistsError, + InvalidTOTPError, + LimitReachedError, + TOTPNormalisedConfig, + UnknownDeviceError, +) +from supertokens_python.asyncio import get_user + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + config: TOTPNormalisedConfig, + ): + super().__init__() + self.querier = querier + self.config = config + + async def get_user_identifier_info_for_user_id( + self, user_id: str, user_context: Dict[str, Any] + ) -> Union[ + UserIdentifierInfoOkResult, + UnknownUserIdError, + UserIdentifierInfoDoesNotExistError, + ]: + user = await get_user(user_id, user_context) + + if user is None: + return UnknownUserIdError() + + primary_login_method = next( + ( + method + for method in user.login_methods + if method.recipe_user_id.get_as_string() == user.id + ), + None, + ) + + if primary_login_method is not None: + if primary_login_method.email is not None: + return UserIdentifierInfoOkResult(primary_login_method.email) + elif primary_login_method.phone_number is not None: + return UserIdentifierInfoOkResult(primary_login_method.phone_number) + + if user.emails: + return UserIdentifierInfoOkResult(user.emails[0]) + elif user.phone_numbers: + return UserIdentifierInfoOkResult(user.phone_numbers[0]) + + return UserIdentifierInfoDoesNotExistError() + + async def create_device( + self, + user_id: str, + user_identifier_info: Optional[str], + device_name: Optional[str], + skew: Optional[int], + period: Optional[int], + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError,]: + if user_identifier_info is None: + email_or_phone_info = await self.get_user_identifier_info_for_user_id( + user_id, user_context + ) + if isinstance(email_or_phone_info, UserIdentifierInfoOkResult): + user_identifier_info = email_or_phone_info.info + elif isinstance(email_or_phone_info, UnknownUserIdError): + return UnknownUserIdError() + + data = { + "userId": user_id, + "deviceName": device_name, + "skew": skew if skew is not None else self.config.default_skew, + "period": period if period is not None else self.config.default_period, + } + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/totp/device"), + data, + user_context=user_context, + ) + + qr_code_string = ( + f"otpauth://totp/{quote(self.config.issuer)}" + f"{':' + quote(user_identifier_info) if user_identifier_info is not None else ''}" + f"?secret={response['secret']}&issuer={quote(self.config.issuer)}&digits=6" + f"&period={period if period is not None else self.config.default_period}" + ) + + return CreateDeviceOkResult( + device_name=response["deviceName"], + secret=response["secret"], + qr_code_string=qr_code_string, + ) + + async def update_device( + self, + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Dict[str, Any], + ) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError,]: + # Prepare the data for the API request + data = { + "userId": user_id, + "existingDeviceName": existing_device_name, + "newDeviceName": new_device_name, + } + + # Send a PUT request to update the device + resp = await self.querier.send_put_request( + NormalisedURLPath("/recipe/totp/device"), + data, + user_context=user_context, + ) + + # Handle the response based on the status + if resp["status"] == "OK": + return UpdateDeviceOkResult() + elif resp["status"] == "UNKNOWN_DEVICE_ERROR": + return UnknownDeviceError() + elif resp["status"] == "DEVICE_ALREADY_EXISTS_ERROR": + return DeviceAlreadyExistsError() + else: + # Raise an exception for unknown errors + raise Exception("Unknown error") + + async def list_devices( + self, user_id: str, user_context: Dict[str, Any] + ) -> ListDevicesOkResult: + params = {"userId": user_id} + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/totp/device/list"), + params, + user_context=user_context, + ) + return ListDevicesOkResult( + devices=[ + Device( + name=device["name"], + period=device["period"], + skew=device["skew"], + verified=device["verified"], + ) + for device in response["devices"] + ] + ) + + async def remove_device( + self, user_id: str, device_name: str, user_context: Dict[str, Any] + ) -> RemoveDeviceOkResult: + data = {"userId": user_id, "deviceName": device_name} + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/totp/device/remove"), + data, + user_context=user_context, + ) + return RemoveDeviceOkResult(did_device_exist=response["didDeviceExist"]) + + async def verify_device( + self, + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + ]: + data = {"userId": user_id, "deviceName": device_name, "totp": totp} + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/totp/device/verify"), + data, + user_context=user_context, + ) + if response["status"] == "OK": + return VerifyDeviceOkResult( + was_already_verified=response["wasAlreadyVerified"] + ) + elif response["status"] == "UNKNOWN_DEVICE_ERROR": + return UnknownDeviceError() + elif response["status"] == "INVALID_TOTP_ERROR": + return InvalidTOTPError( + current_number_of_failed_attempts=response[ + "currentNumberOfFailedAttempts" + ], + max_number_of_failed_attempts=response["maxNumberOfFailedAttempts"], + ) + elif response["status"] == "LIMIT_REACHED_ERROR": + return LimitReachedError( + retry_after_ms=response["retryAfterMs"], + ) + else: + raise Exception("Unknown error") + + async def verify_totp( + self, tenant_id: str, user_id: str, totp: str, user_context: Dict[str, Any] + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + ]: + data = {"userId": user_id, "totp": totp} + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/totp/verify"), + data, + user_context=user_context, + ) + if response["status"] == "OK": + return VerifyTOTPOkResult() + elif response["status"] == "UNKNOWN_USER_ID_ERROR": + return UnknownUserIdError() + elif response["status"] == "INVALID_TOTP_ERROR": + return InvalidTOTPError( + current_number_of_failed_attempts=response[ + "currentNumberOfFailedAttempts" + ], + max_number_of_failed_attempts=response["maxNumberOfFailedAttempts"], + ) + elif response["status"] == "LIMIT_REACHED_ERROR": + return LimitReachedError( + retry_after_ms=response["retryAfterMs"], + ) + else: + raise Exception("Unknown error") diff --git a/supertokens_python/recipe/totp/syncio/__init__.py b/supertokens_python/recipe/totp/syncio/__init__.py new file mode 100644 index 000000000..2eaecf37f --- /dev/null +++ b/supertokens_python/recipe/totp/syncio/__init__.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, Dict, Union, Optional + +from supertokens_python.async_to_sync_wrapper import sync + +from ..recipe import TOTPRecipe +from supertokens_python.recipe.totp.types import ( + CreateDeviceOkResult, + DeviceAlreadyExistsError, + UnknownUserIdError, + UpdateDeviceOkResult, + UnknownDeviceError, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + InvalidTOTPError, + LimitReachedError, + VerifyTOTPOkResult, +) + + +def create_device( + user_id: str, + user_identifier_info: Optional[str] = None, + device_name: Optional[str] = None, + skew: Optional[int] = None, + period: Optional[int] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import create_device as async_func + + return sync( + async_func( + user_id, user_identifier_info, device_name, skew, period, user_context + ) + ) + + +def update_device( + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import update_device as async_func + + return sync( + async_func(user_id, existing_device_name, new_device_name, user_context) + ) + + +def list_devices( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> ListDevicesOkResult: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import list_devices as async_func + + return sync(async_func(user_id, user_context)) + + +def remove_device( + user_id: str, + device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> RemoveDeviceOkResult: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import remove_device as async_func + + return sync(async_func(user_id, device_name, user_context)) + + +def verify_device( + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + VerifyDeviceOkResult, UnknownDeviceError, InvalidTOTPError, LimitReachedError +]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import verify_device as async_func + + return sync(async_func(tenant_id, user_id, device_name, totp, user_context)) + + +def verify_totp( + tenant_id: str, + user_id: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[VerifyTOTPOkResult, UnknownUserIdError, InvalidTOTPError, LimitReachedError]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import verify_totp as async_func + + return sync(async_func(tenant_id, user_id, totp, user_context)) + + +init = TOTPRecipe.init diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py new file mode 100644 index 000000000..fae156fe8 --- /dev/null +++ b/supertokens_python/recipe/totp/types.py @@ -0,0 +1,221 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Any, Callable, Optional +from typing_extensions import Literal +from .interfaces import RecipeInterface, APIInterface + +from supertokens_python.types import APIResponse + + +class OkResult(APIResponse): + def __init__(self): + self.status: Literal["OK"] = "OK" + + +class UserIdentifierInfoOkResult(OkResult): + def __init__(self, info: str): + super().__init__() + self.info: str = info + + def to_json(self) -> Dict[str, Any]: + raise NotImplementedError() + + +class UnknownUserIdError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_USER_ID_ERROR"] = "UNKNOWN_USER_ID_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class UserIdentifierInfoDoesNotExistError: + def __init__(self): + self.status: Literal[ + "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" + ] = "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" + + +class CreateDeviceOkResult(OkResult): + def __init__(self, device_name: str, secret: str, qr_code_string: str): + super().__init__() + self.device_name: str = device_name + self.secret: str = secret + self.qr_code_string: str = qr_code_string + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "deviceName": self.device_name, + "secret": self.secret, + "qrCodeString": self.qr_code_string, + } + + +class DeviceAlreadyExistsError(APIResponse): + def __init__(self): + self.status: Literal[ + "DEVICE_ALREADY_EXISTS_ERROR" + ] = "DEVICE_ALREADY_EXISTS_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class UpdateDeviceOkResult(OkResult): + def __init__(self): + super().__init__() + + def to_json(self) -> Dict[str, Any]: + raise NotImplementedError() + + +class UnknownDeviceError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_DEVICE_ERROR"] = "UNKNOWN_DEVICE_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class Device(APIResponse): + def __init__(self, name: str, period: int, skew: int, verified: bool): + self.name: str = name + self.period: int = period + self.skew: int = skew + self.verified: bool = verified + + def to_json(self) -> Dict[str, Any]: + return { + "name": self.name, + "period": self.period, + "skew": self.skew, + "verified": self.verified, + } + + +class ListDevicesOkResult(OkResult): + def __init__(self, devices: List[Device]): + super().__init__() + self.devices: List[Device] = devices + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "devices": [device.to_json() for device in self.devices], + } + + +class RemoveDeviceOkResult(OkResult): + def __init__(self, did_device_exist: bool): + super().__init__() + self.did_device_exist: bool = did_device_exist + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "didDeviceExist": self.did_device_exist, + } + + +class VerifyDeviceOkResult(OkResult): + def __init__( + self, + was_already_verified: bool, + ): + super().__init__() + self.was_already_verified: bool = was_already_verified + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "wasAlreadyVerified": self.was_already_verified, + } + + +class InvalidTOTPError(APIResponse): + def __init__( + self, current_number_of_failed_attempts: int, max_number_of_failed_attempts: int + ): + self.status: Literal["INVALID_TOTP_ERROR"] = "INVALID_TOTP_ERROR" + self.current_number_of_failed_attempts: int = current_number_of_failed_attempts + self.max_number_of_failed_attempts: int = max_number_of_failed_attempts + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "currentNumberOfFailedAttempts": self.current_number_of_failed_attempts, + "maxNumberOfFailedAttempts": self.max_number_of_failed_attempts, + } + + +class LimitReachedError(APIResponse): + def __init__(self, retry_after_ms: int): + self.status: Literal["LIMIT_REACHED_ERROR"] = "LIMIT_REACHED_ERROR" + self.retry_after_ms: int = retry_after_ms + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "retryAfterMs": self.retry_after_ms, + } + + +class VerifyTOTPOkResult(OkResult): + def __init__( + self, + ): + super().__init__() + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class OverrideConfig: + def __init__( + self, + functions: Optional[Callable[[RecipeInterface], RecipeInterface]] = None, + apis: Optional[Callable[[APIInterface], APIInterface]] = None, + ): + self.functions = functions + self.apis = apis + + +class TOTPConfig: + def __init__( + self, + issuer: Optional[str] = None, + default_skew: Optional[int] = None, + default_period: Optional[int] = None, + override: Optional[OverrideConfig] = None, + ): + self.issuer = issuer + self.default_skew = default_skew + self.default_period = default_period + self.override = override + + +class TOTPNormalisedConfig: + def __init__( + self, + issuer: str, + default_skew: int, + default_period: int, + override: OverrideConfig, + ): + self.issuer = issuer + self.default_skew = default_skew + self.default_period = default_period + self.override = override diff --git a/supertokens_python/recipe/totp/utils.py b/supertokens_python/recipe/totp/utils.py new file mode 100644 index 000000000..4684fa7f8 --- /dev/null +++ b/supertokens_python/recipe/totp/utils.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Union + +from supertokens_python import AppInfo +from .types import TOTPConfig, TOTPNormalisedConfig, OverrideConfig + + +def validate_and_normalise_user_input( + app_info: AppInfo, config: Union[TOTPConfig, None] +) -> TOTPNormalisedConfig: + if config is None: + config = TOTPConfig() + + issuer = config.issuer if config.issuer is not None else app_info.app_name + default_skew = config.default_skew if config.default_skew is not None else 1 + default_period = config.default_period if config.default_period is not None else 30 + + if config.override is None: + override = OverrideConfig() + else: + override = OverrideConfig( + functions=config.override.functions, + apis=config.override.apis, + ) + + return TOTPNormalisedConfig( + issuer=issuer, + default_skew=default_skew, + default_period=default_period, + override=override, + ) From 35c7cd09c1cd0db18269140cb690187b9feccc56 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 26 Aug 2024 15:28:50 +0530 Subject: [PATCH 023/126] more changes --- supertokens_python/supertokens.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index f99119bff..eb283ead3 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -255,8 +255,10 @@ def __init__( "Please provide at least one recipe to the supertokens.init function call" ) - # from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe - # from supertokens_python.recipe.totp.recipe import TOTPRecipe + from supertokens_python.recipe.multifactorauth.recipe import ( + MultiFactorAuthRecipe, + ) + from supertokens_python.recipe.totp.recipe import TOTPRecipe multitenancy_found = False totp_found = False @@ -270,10 +272,10 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: multitenancy_found = True elif recipe_module.get_recipe_id() == "usermetadata": user_metadata_found = True - # elif recipe_module.get_recipe_id() == MultiFactorAuthRecipe.recipe_id: - # multi_factor_auth_found = True - # elif recipe_module.get_recipe_id() == TOTPRecipe.recipe_id: - # totp_found = True + elif recipe_module.get_recipe_id() == MultiFactorAuthRecipe.recipe_id: + multi_factor_auth_found = True + elif recipe_module.get_recipe_id() == TOTPRecipe.recipe_id: + totp_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) @@ -283,9 +285,7 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) if totp_found and not multi_factor_auth_found: - raise_general_exception( - "Please initialize the MultiFactorAuth recipe to use TOTP." - ) + raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") if not user_metadata_found: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe From d5028e675f6db505b293d6f6c7dc34f924efa26e Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 26 Aug 2024 20:28:34 +0530 Subject: [PATCH 024/126] multi tenancy recipe --- .../recipe/dashboard/interfaces.py | 3 - .../recipe/multitenancy/api/implementation.py | 30 ++++++- .../recipe/multitenancy/interfaces.py | 22 ++--- .../recipe/multitenancy/recipe.py | 2 + .../multitenancy/recipe_implementation.py | 22 +++-- tests/emailpassword/test_multitenancy.py | 6 +- tests/multitenancy/test_tenants_crud.py | 90 +++++++++++++------ tests/passwordless/test_mutlitenancy.py | 21 ++++- tests/test-server/multitenancy.py | 12 +-- tests/thirdparty/test_multitenancy.py | 12 +-- tests/userroles/test_multitenancy.py | 6 +- 11 files changed, 151 insertions(+), 75 deletions(-) diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index 2e93b40f5..53df121c5 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -110,9 +110,6 @@ def __init__(self, tenant_config: TenantConfig): def to_json(self) -> Dict[str, Any]: return { "tenantId": self.tenant_config.tenant_id, - "emailPassword": {"enabled": self.tenant_config.email_password_enabled}, - "passwordless": {"enabled": self.tenant_config.passwordless_enabled}, - "thirdParty": {"enabled": self.tenant_config.third_party_enabled}, } diff --git a/supertokens_python/recipe/multitenancy/api/implementation.py b/supertokens_python/recipe/multitenancy/api/implementation.py index 38f599352..5806afb57 100644 --- a/supertokens_python/recipe/multitenancy/api/implementation.py +++ b/supertokens_python/recipe/multitenancy/api/implementation.py @@ -25,6 +25,7 @@ from supertokens_python.types import GeneralErrorResponse from ..interfaces import APIInterface, ThirdPartyProvider +from ...multifactorauth.utils import is_valid_first_factor class APIImplementation(APIInterface): @@ -80,12 +81,35 @@ async def login_methods_get( ThirdPartyProvider(provider_instance.id, provider_instance.config.name) ) + first_factors: List[str] = [] + if tenant_config.first_factors is not None: + first_factors = tenant_config.first_factors + elif api_options.static_first_factors is not None: + first_factors = api_options.static_first_factors + else: + first_factors = list(set(api_options.all_available_first_factors)) + + valid_first_factors: List[str] = [] + for factor_id in first_factors: + valid_res = await is_valid_first_factor(tenant_id, factor_id, user_context) + if valid_res == "OK": + valid_first_factors.append(factor_id) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise Exception("Tenant not found") + return LoginMethodsGetOkResult( email_password=LoginMethodEmailPassword( - tenant_config.email_password_enabled + enabled="emailpassword" in valid_first_factors + ), + passwordless=LoginMethodPasswordless( + enabled=any( + factor in valid_first_factors + for factor in ["otp-email", "otp-phone", "link-email", "link-phone"] + ) ), - passwordless=LoginMethodPasswordless(tenant_config.passwordless_enabled), third_party=LoginMethodThirdParty( - tenant_config.third_party_enabled, final_provider_list + enabled="thirdparty" in valid_first_factors, + providers=final_provider_list, ), + first_factors=valid_first_factors, ) diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 04873ef04..0cc0f223d 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -33,31 +33,16 @@ def __init__( self, tenant_id: str = "", third_party_providers: List[ProviderConfig] = [], - email_password_enabled: bool = False, - passwordless_enabled: bool = False, - third_party_enabled: bool = False, core_config: Dict[str, Any] = {}, first_factors: Optional[List[str]] = None, required_secondary_factors: Optional[List[str]] = None, ): self.tenant_id = tenant_id - self.email_password_enabled = email_password_enabled - self.passwordless_enabled = passwordless_enabled - self.third_party_enabled = third_party_enabled self.core_config = core_config self.first_factors = first_factors self.required_secondary_factors = required_secondary_factors self.third_party_providers = third_party_providers - def to_json(self) -> Dict[str, Any]: - res: Dict[str, Any] = {} - res["tenantId"] = self.tenant_id - res["emailPasswordEnabled"] = self.email_password_enabled - res["passwordlessEnabled"] = self.passwordless_enabled - res["thirdPartyEnabled"] = self.third_party_enabled - res["coreConfig"] = self.core_config - return res - class CreateOrUpdateTenantOkResult: status = "OK" @@ -216,6 +201,8 @@ def __init__( config: MultitenancyConfig, recipe_implementation: RecipeInterface, static_third_party_providers: List[ProviderInput], + all_available_first_factors: List[str], + static_first_factors: Optional[List[str]], ): self.request = request self.response = response @@ -223,6 +210,8 @@ def __init__( self.config = config self.recipe_implementation = recipe_implementation self.static_third_party_providers = static_third_party_providers + self.static_first_factors = static_first_factors + self.all_available_first_factors = all_available_first_factors class ThirdPartyProvider: @@ -277,11 +266,13 @@ def __init__( email_password: LoginMethodEmailPassword, passwordless: LoginMethodPasswordless, third_party: LoginMethodThirdParty, + first_factors: List[str], ): self.status = "OK" self.email_password = email_password self.passwordless = passwordless self.third_party = third_party + self.first_factors = first_factors def to_json(self) -> Dict[str, Any]: return { @@ -289,6 +280,7 @@ def to_json(self) -> Dict[str, Any]: "emailPassword": self.email_password.to_json(), "passwordless": self.passwordless.to_json(), "thirdParty": self.third_party.to_json(), + "firstFactors": self.first_factors, } diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 5cfea3ce2..1c46bcf78 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -125,6 +125,8 @@ async def handle_api_request( self.config, self.recipe_implementation, self.static_third_party_providers, + self.all_available_first_factors, + self.static_first_factors, ) return await handle_login_methods_api( self.api_implementation, diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 1e2ef5417..8919cecd5 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -105,11 +105,10 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfig: return TenantConfig( tenant_id=tenant["tenantId"], - email_password_enabled=tenant["emailPassword"]["enabled"], - passwordless_enabled=tenant["passwordless"]["enabled"], third_party_providers=providers, - third_party_enabled=tenant["thirdParty"]["enabled"], core_config=tenant["coreConfig"], + first_factors=tenant.get("firstFactors"), + required_secondary_factors=tenant.get("requiredSecondaryFactors"), ) @@ -131,10 +130,19 @@ async def create_or_update_tenant( user_context: Dict[str, Any], ) -> CreateOrUpdateTenantOkResult: response = await self.querier.send_put_request( - NormalisedURLPath("/recipe/multitenancy/tenant"), + NormalisedURLPath("/recipe/multitenancy/tenant/v2"), { "tenantId": tenant_id, - **(config.to_json() if config is not None else {}), + "firstFactors": ( + config.first_factors + if config and config.first_factors is not None + else None + ), + "requiredSecondaryFactors": ( + config.required_secondary_factors + if config and config.required_secondary_factors is not None + else None + ), }, user_context=user_context, ) @@ -159,7 +167,7 @@ async def get_tenant( ) -> Optional[TenantConfig]: res = await self.querier.send_get_request( NormalisedURLPath( - f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant" + f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/v2" ), None, user_context=user_context, @@ -176,7 +184,7 @@ async def list_all_tenants( self, user_context: Dict[str, Any] ) -> ListAllTenantsOkResult: response = await self.querier.send_get_request( - NormalisedURLPath("/recipe/multitenancy/tenant/list"), + NormalisedURLPath("/recipe/multitenancy/tenant/list/v2"), {}, user_context=user_context, ) diff --git a/tests/emailpassword/test_multitenancy.py b/tests/emailpassword/test_multitenancy.py index eaa87d5bd..b57683bc0 100644 --- a/tests/emailpassword/test_multitenancy.py +++ b/tests/emailpassword/test_multitenancy.py @@ -62,9 +62,9 @@ async def test_multitenancy_in_emailpassword(): setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(email_password_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant("t2", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant("t3", TenantConfig(first_factors=["emailpassword"])) user1 = await sign_up("t1", "test@example.com", "password1") user2 = await sign_up("t2", "test@example.com", "password2") diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 051dc01b8..7b49113d2 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -67,51 +67,84 @@ async def test_tenant_crud(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t2", + TenantConfig(first_factors=["otp-email, otp-phone, link-email, link-phone"]), + ) + await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) tenants = await list_all_tenants() - assert len(tenants.tenants) == 4 + assert len(tenants.tenants) == 3 t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config.email_password_enabled is True - assert t1_config.passwordless_enabled is False - assert t1_config.third_party_enabled is False + assert t1_config.first_factors is not None + assert "emailpassword" in t1_config.first_factors + assert "otp-email" in t1_config.first_factors + assert "otp-phone" in t1_config.first_factors + assert "link-email" in t1_config.first_factors + assert "link-phone" in t1_config.first_factors + assert "thirdparty" in t1_config.first_factors assert t1_config.core_config == {} t2_config = await get_tenant("t2") assert t2_config is not None - assert t2_config.email_password_enabled is False - assert t2_config.passwordless_enabled is True - assert t2_config.third_party_enabled is False + assert t2_config.first_factors is not None + assert "emailpassword" in t2_config.first_factors + assert "otp-email" in t2_config.first_factors + assert "otp-phone" in t2_config.first_factors + assert "link-email" in t2_config.first_factors + assert "link-phone" in t2_config.first_factors + assert "thirdparty" in t2_config.first_factors assert t2_config.core_config == {} t3_config = await get_tenant("t3") assert t3_config is not None - assert t3_config.email_password_enabled is False - assert t3_config.passwordless_enabled is False - assert t3_config.third_party_enabled is True + assert t3_config.first_factors is not None + assert "emailpassword" in t3_config.first_factors + assert "otp-email" in t3_config.first_factors + assert "otp-phone" in t3_config.first_factors + assert "link-email" in t3_config.first_factors + assert "link-phone" in t3_config.first_factors + assert "thirdparty" in t3_config.first_factors assert t3_config.core_config == {} # update tenant1 to add passwordless: - await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True)) + await create_or_update_tenant( + "t1", + TenantConfig( + first_factors=[ + "otp-email", + "otp-phone", + "link-email", + "link-phone", + ] + ), + ) t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config.email_password_enabled is True - assert t1_config.passwordless_enabled is True - assert t1_config.third_party_enabled is False + assert t1_config.first_factors is not None + assert "emailpassword" in t1_config.first_factors + assert "otp-email" in t1_config.first_factors + assert "otp-phone" in t1_config.first_factors + assert "link-email" in t1_config.first_factors + assert "link-phone" in t1_config.first_factors + assert "thirdparty" in t1_config.first_factors assert t1_config.core_config == {} # update tenant1 to add thirdparty: - await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["thirdparty"])) t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config is not None - assert t1_config.email_password_enabled is True - assert t1_config.passwordless_enabled is True - assert t1_config.third_party_enabled is True + assert t1_config.first_factors is not None + assert "emailpassword" in t1_config.first_factors + assert "otp-email" in t1_config.first_factors + assert "otp-phone" in t1_config.first_factors + assert "link-email" in t1_config.first_factors + assert "link-phone" in t1_config.first_factors + assert "thirdparty" in t1_config.first_factors + assert t1_config.core_config == {} assert t1_config.core_config == {} # delete tenant2: @@ -126,7 +159,7 @@ async def test_tenant_thirdparty_config(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) await create_or_update_third_party_config( "t1", config=ProviderConfig( @@ -253,9 +286,14 @@ async def test_user_association_and_disassociation_with_tenants(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t2", + TenantConfig( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) signup_response = await sign_up("public", "test@example.com", "password1") assert isinstance(signup_response, SignUpOkResult) diff --git a/tests/passwordless/test_mutlitenancy.py b/tests/passwordless/test_mutlitenancy.py index 1121b11da..623ee25c1 100644 --- a/tests/passwordless/test_mutlitenancy.py +++ b/tests/passwordless/test_mutlitenancy.py @@ -57,9 +57,24 @@ async def test_multitenancy_functions(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(passwordless_enabled=True)) + await create_or_update_tenant( + "t1", + TenantConfig( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant( + "t2", + TenantConfig( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant( + "t3", + TenantConfig( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) code1 = await create_code( tenant_id="t1", email="test@example.com", user_input_code="123456" diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index 5f9625c38..03bdd342f 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -23,9 +23,9 @@ def create_or_update_tenant(): # type: ignore user_context = data.get("userContext") config = TenantConfig( - email_password_enabled=config.get("emailPasswordEnabled"), - passwordless_enabled=config.get("passwordlessEnabled"), - third_party_enabled=config.get("thirdPartyEnabled"), + first_factors=config.get("firstFactors"), + required_secondary_factors=config.get("requiredSecondaryFactors"), + third_party_providers=config.get("thirdPartyProviders"), core_config=config.get("coreConfig"), ) @@ -62,9 +62,9 @@ def get_tenant(): # type: ignore { "status": "OK", "tenant": { - "emailPassword": {"enabled": response.email_password_enabled}, - "thirdParty": {"enabled": response.third_party_enabled}, - "passwordless": {"enabled": response.passwordless_enabled}, + "firstFactors": response.first_factors, + "requiredSecondaryFactors": response.required_secondary_factors, + "thirdPartyProviders": response.third_party_providers, "coreConfig": response.core_config, }, } diff --git a/tests/thirdparty/test_multitenancy.py b/tests/thirdparty/test_multitenancy.py index 8c205f04d..d507fd409 100644 --- a/tests/thirdparty/test_multitenancy.py +++ b/tests/thirdparty/test_multitenancy.py @@ -49,9 +49,9 @@ async def test_thirtyparty_multitenancy_functions(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant("t2", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) # sign up: user1a = await manually_create_or_update_user( @@ -157,9 +157,9 @@ async def test_get_provider(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant("t2", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) await create_or_update_third_party_config( "t1", diff --git a/tests/userroles/test_multitenancy.py b/tests/userroles/test_multitenancy.py index 54017f1ff..ddf02c6c1 100644 --- a/tests/userroles/test_multitenancy.py +++ b/tests/userroles/test_multitenancy.py @@ -56,9 +56,9 @@ async def test_multitenancy_in_user_roles(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(email_password_enabled=True)) + await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant("t2", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant("t3", TenantConfig(first_factors=["emailpassword"])) user = await sign_up("public", "test@example.com", "password1") assert isinstance(user, SignUpOkResult) From b3b93f78f55d65ff24e16cd7fd9380c330fae6ae Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 28 Aug 2024 00:00:28 +0530 Subject: [PATCH 025/126] adds auth utils --- .pylintrc | 3 +- supertokens_python/auth_utils.py | 952 ++++++++++++++++++ .../recipe/multitenancy/asyncio/__init__.py | 10 +- .../recipe/multitenancy/interfaces.py | 6 +- .../multitenancy/recipe_implementation.py | 15 +- .../recipe/multitenancy/syncio/__init__.py | 9 +- supertokens_python/recipe/session/__init__.py | 2 + supertokens_python/recipe/session/recipe.py | 4 + supertokens_python/recipe/session/utils.py | 8 + tests/multitenancy/test_tenants_crud.py | 13 +- tests/userroles/test_multitenancy.py | 7 +- 11 files changed, 1004 insertions(+), 25 deletions(-) create mode 100644 supertokens_python/auth_utils.py diff --git a/.pylintrc b/.pylintrc index 814178ce3..f9b838184 100644 --- a/.pylintrc +++ b/.pylintrc @@ -122,7 +122,8 @@ disable=raw-checker-failed, consider-using-f-string, consider-using-in, no-else-return, - no-self-use + no-self-use, + no-else-raise # Enable the message, report, category or checker with the given id(s). You can diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py new file mode 100644 index 000000000..8e4ce9015 --- /dev/null +++ b/supertokens_python/auth_utils.py @@ -0,0 +1,952 @@ +from typing import Awaitable, Callable, Dict, Any, Optional, Union, List +from typing_extensions import Literal +from supertokens_python.framework import BaseRequest +from supertokens_python.recipe.accountlinking import ( + AccountInfoWithRecipeIdAndUserId, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.accountlinking.utils import ( + recipe_init_defined_should_do_automatic_account_linking, +) +from supertokens_python.recipe.multifactorauth.asyncio import ( + mark_factor_as_complete_in_session, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.utils import ( + is_valid_first_factor, + update_and_get_mfa_related_info_in_session, +) +from supertokens_python.recipe.multitenancy.asyncio import associate_user_to_tenant +from supertokens_python.recipe.session.interfaces import SessionContainer +from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.types import ( + AccountInfo, + AccountLinkingUser, + LoginMethod, + ThirdPartyInfo, +) +from supertokens_python.recipe.accountlinking.interfaces import ( + RecipeUserId, +) +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.emailverification import ( + EmailVerificationClaim, +) +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.utils import log_debug_message +from .asyncio import get_user + + +class OkResponse: + status: Literal["OK"] + valid_factor_ids: List[str] + is_first_factor: bool + + def __init__(self, valid_factor_ids: List[str], is_first_factor: bool): + self.status = "OK" + self.valid_factor_ids = valid_factor_ids + self.is_first_factor = is_first_factor + + +class SignUpNotAllowedResponse: + status: Literal["SIGN_UP_NOT_ALLOWED"] + + +class SignInNotAllowedResponse: + status: Literal["SIGN_IN_NOT_ALLOWED"] + + +class LinkingToSessionUserFailedResponse: + status: Literal["LINKING_TO_SESSION_USER_FAILED"] + reason: Literal[ + "EMAIL_VERIFICATION_REQUIRED", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "INPUT_USER_IS_NOT_A_PRIMARY_USER", + ] + + def __init__( + self, + reason: Literal[ + "EMAIL_VERIFICATION_REQUIRED", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "INPUT_USER_IS_NOT_A_PRIMARY_USER", + ], + ): + self.reason = reason + + +async def pre_auth_checks( + authenticating_account_info: AccountInfoWithRecipeId, + authenticating_user: Union[AccountLinkingUser, None], + tenant_id: str, + factor_ids: List[str], + is_sign_up: bool, + is_verified: bool, + sign_in_verifies_login_method: bool, + skip_session_user_update_in_core: bool, + session: Union[SessionContainer, None], + user_context: Dict[str, Any], +) -> Union[ + OkResponse, + SignUpNotAllowedResponse, + SignInNotAllowedResponse, + LinkingToSessionUserFailedResponse, +]: + valid_factor_ids: List[str] = [] + + if len(factor_ids) == 0: + raise Exception( + "This should never happen: empty factorIds array passed to preSignInChecks" + ) + + log_debug_message("preAuthChecks checking auth types") + auth_type_info = await check_auth_type_and_linking_status( + session, + authenticating_account_info, + authenticating_user, + skip_session_user_update_in_core, + user_context, + ) + if auth_type_info.status != "OK": + log_debug_message( + f"preAuthChecks returning {auth_type_info.status} from checkAuthType results" + ) + return auth_type_info + + if auth_type_info.is_first_factor: + log_debug_message("preAuthChecks getting valid first factors") + valid_first_factors = ( + await filter_out_invalid_first_factors_or_throw_if_all_are_invalid( + factor_ids, tenant_id, session is not None, user_context + ) + ) + valid_factor_ids = valid_first_factors + else: + assert isinstance( + auth_type_info, + (OkSecondFactorNotLinkedResponse, OkSecondFactorLinkedResponse), + ) + assert session is not None + log_debug_message("preAuthChecks getting valid secondary factors") + valid_factor_ids = ( + await filter_out_invalid_second_factors_or_throw_if_all_are_invalid( + factor_ids, + auth_type_info.input_user_already_linked_to_session_user, + auth_type_info.session_user, + session, + user_context, + ) + ) + + if not is_sign_up and authenticating_user is None: + raise Exception( + "This should never happen: preAuthChecks called with isSignUp: false, authenticatingUser: None" + ) + + if is_sign_up: + verified_in_session_user = not isinstance( + auth_type_info, OkFirstFactorResponse + ) and any( + lm.verified + and ( + lm.has_same_email_as(authenticating_account_info.email) + or lm.has_same_phone_number_as(authenticating_account_info.phone_number) + ) + for lm in auth_type_info.session_user.login_methods + ) + + log_debug_message("preAuthChecks checking if the user is allowed to sign up") + if not await AccountLinkingRecipe.get_instance().is_sign_up_allowed( + new_user=authenticating_account_info, + is_verified=is_verified + or sign_in_verifies_login_method + or verified_in_session_user, + tenant_id=tenant_id, + session=session, + user_context=user_context, + ): + return SignUpNotAllowedResponse() + elif authenticating_user is not None: + log_debug_message("preAuthChecks checking if the user is allowed to sign in") + if not await AccountLinkingRecipe.get_instance().is_sign_in_allowed( + user=authenticating_user, + account_info=authenticating_account_info, + sign_in_verifies_login_method=sign_in_verifies_login_method, + tenant_id=tenant_id, + session=session, + user_context=user_context, + ): + return SignInNotAllowedResponse() + + log_debug_message("preAuthChecks returning OK") + return OkResponse( + valid_factor_ids=valid_factor_ids, + is_first_factor=auth_type_info.is_first_factor, + ) + + +class PostAuthChecksOkResponse: + status: Literal["OK"] + session: SessionContainer + user: AccountLinkingUser + + def __init__( + self, status: Literal["OK"], session: SessionContainer, user: AccountLinkingUser + ): + self.status = status + self.session = session + self.user = user + + +class PostAuthChecksSignInNotAllowedResponse: + status: Literal["SIGN_IN_NOT_ALLOWED"] + + +async def post_auth_checks( + authenticated_user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + is_sign_up: bool, + factor_id: str, + session: Union[SessionContainer, None], + tenant_id: str, + user_context: Dict[str, Any], + request: BaseRequest, +) -> Union[PostAuthChecksOkResponse, PostAuthChecksSignInNotAllowedResponse]: + log_debug_message( + f"postAuthChecks called {'with' if session is not None else 'without'} a session to " + f"{'sign up' if is_sign_up else 'sign in'} with {factor_id}" + ) + + mfa_instance = MultiFactorAuthRecipe.get_instance() + + resp_session = session + if session is not None: + authenticated_user_linked_to_session_user = any( + lm.recipe_user_id.get_as_string() + == session.get_recipe_user_id(user_context).get_as_string() + for lm in authenticated_user.login_methods + ) + if authenticated_user_linked_to_session_user: + log_debug_message("postAuthChecks session and input user got linked") + if mfa_instance is not None: + log_debug_message("postAuthChecks marking factor as completed") + # if the authenticating user is linked to the current session user (it means that the factor got set up or completed), + # we mark it as completed in the session. + assert resp_session is not None + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context + ) + else: + log_debug_message("postAuthChecks checking overwriteSessionDuringSignInUp") + # If the new user wasn't linked to the current one, we check the config and overwrite the session if required + # Note: we could also get here if MFA is enabled, but the app didn't want to link the user to the session user. + # This is intentional, since the MFA and overwriteSessionDuringSignInUp configs should work independently. + overwrite_session_during_sign_in_up = ( + SessionRecipe.get_instance().config.overwrite_session_during_sign_in_up + ) + if overwrite_session_during_sign_in_up: + resp_session = await create_new_session( + request, tenant_id, recipe_user_id, {}, {}, user_context + ) + if mfa_instance is not None: + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context + ) + else: + log_debug_message("postAuthChecks creating session for first factor sign in/up") + # If there is no input session, we do not need to do anything other checks and create a new session + resp_session = await create_new_session( + request, tenant_id, recipe_user_id, {}, {}, user_context + ) + + # Here we can always mark the factor as completed, since we just created the session + if mfa_instance is not None: + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context + ) + + assert resp_session is not None + return PostAuthChecksOkResponse( + status="OK", session=resp_session, user=authenticated_user + ) + + +class AuthenticatingUserInfo: + def __init__( + self, user: AccountLinkingUser, login_method: Union[LoginMethod, None] + ): + self.user = user + self.login_method = login_method + + +async def get_authenticating_user_and_add_to_current_tenant_if_required( + recipe_id: str, + email: Optional[str], + phone_number: Optional[str], + third_party: Optional[ThirdPartyInfo], + tenant_id: str, + session: Optional[SessionContainer], + check_credentials_on_tenant: Callable[[str], Awaitable[bool]], + user_context: Dict[str, Any], +) -> Optional[AuthenticatingUserInfo]: + i = 0 + while i < 300: + account_info = { + "email": email, + "phoneNumber": phone_number, + "thirdParty": third_party, + } + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired called with {account_info}" + ) + existing_users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo( + email=email, phone_number=phone_number, third_party=third_party + ), + do_union_of_account_info=True, + user_context=user_context, + ) + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired got {len(existing_users)} users from the core resp" + ) + users_with_matching_login_methods = [ + AuthenticatingUserInfo( + user=user, + login_method=next( + ( + lm + for lm in user.login_methods + if lm.recipe_id == recipe_id + and ( + lm.has_same_email_as(email) + or lm.has_same_phone_number_as(phone_number) + or lm.has_same_third_party_info_as(third_party) + ) + ), + None, + ), + ) + for user in existing_users + ] + users_with_matching_login_methods = [ + u for u in users_with_matching_login_methods if u.login_method is not None + ] + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired got {len(users_with_matching_login_methods)} users with matching login methods" + ) + if len(users_with_matching_login_methods) > 1: + raise Exception( + "You have found a bug. Please report it on https://github.com/supertokens/supertokens-node/issues" + ) + authenticating_user = ( + AuthenticatingUserInfo( + users_with_matching_login_methods[0].user, + users_with_matching_login_methods[0].login_method, + ) + if users_with_matching_login_methods + else None + ) + + if authenticating_user is None and session is not None: + log_debug_message( + "getAuthenticatingUserAndAddToCurrentTenantIfRequired checking session user" + ) + session_user = await get_user( + session.get_user_id(user_context), user_context + ) + if session_user is None: + raise UnauthorisedError( + "Session user not found", + ) + + if not session_user.is_primary_user: + log_debug_message( + "getAuthenticatingUserAndAddToCurrentTenantIfRequired session user is non-primary so returning early without checking other tenants" + ) + return None + + matching_login_methods_from_session_user = [ + lm + for lm in session_user.login_methods + if lm.recipe_id == recipe_id + and ( + lm.has_same_email_as(email) + or lm.has_same_phone_number_as(phone_number) + or lm.has_same_third_party_info_as(third_party) + ) + ] + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired session has {len(matching_login_methods_from_session_user)} matching login methods" + ) + + if any( + tenant_id in lm.tenant_ids + for lm in matching_login_methods_from_session_user + ): + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired session has {len(matching_login_methods_from_session_user)} matching login methods" + ) + return AuthenticatingUserInfo( + user=session_user, + login_method=next( + lm + for lm in matching_login_methods_from_session_user + if tenant_id in lm.tenant_ids + ), + ) + + go_to_retry = False + for lm in matching_login_methods_from_session_user: + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired session checking credentials on {lm.tenant_ids[0]}" + ) + if await check_credentials_on_tenant(lm.tenant_ids[0]): + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired associating user from {lm.tenant_ids[0]} with current tenant" + ) + associate_res = await associate_user_to_tenant( + tenant_id, lm.recipe_user_id, user_context + ) + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired associating returned {associate_res.status}" + ) + if associate_res.status == "OK": + lm.tenant_ids.append(tenant_id) + return AuthenticatingUserInfo( + user=session_user, login_method=lm + ) + if associate_res.status in [ + "UNKNOWN_USER_ID_ERROR", + "EMAIL_ALREADY_EXISTS_ERROR", + "PHONE_NUMBER_ALREADY_EXISTS_ERROR", + "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR", + ]: + go_to_retry = True + break + if associate_res.status == "ASSOCIATION_NOT_ALLOWED_ERROR": + raise UnauthorisedError( + "Session user not associated with the session tenant" + ) + if go_to_retry: + log_debug_message( + "getAuthenticatingUserAndAddToCurrentTenantIfRequired retrying" + ) + i += 1 + continue + return authenticating_user + raise Exception( + "This should never happen: ran out of retries for getAuthenticatingUserAndAddToCurrentTenantIfRequired" + ) + + +class OkFirstFactorResponse: + status: Literal["OK"] + is_first_factor: Literal[True] + + +class OkSecondFactorLinkedResponse: + status: Literal["OK"] + is_first_factor: Literal[False] + input_user_already_linked_to_session_user: Literal[True] + session_user: AccountLinkingUser + + def __init__(self, session_user: AccountLinkingUser): + self.session_user = session_user + + +class OkSecondFactorNotLinkedResponse: + status: Literal["OK"] + is_first_factor: Literal[False] + input_user_already_linked_to_session_user: Literal[False] + session_user: AccountLinkingUser + linking_to_session_user_requires_verification: bool + + def __init__( + self, + session_user: AccountLinkingUser, + linking_to_session_user_requires_verification: bool, + ): + self.session_user = session_user + self.linking_to_session_user_requires_verification = ( + linking_to_session_user_requires_verification + ) + + +async def check_auth_type_and_linking_status( + session: Union[SessionContainer, None], + account_info: AccountInfoWithRecipeId, + input_user: Union[AccountLinkingUser, None], + skip_session_user_update_in_core: bool, + user_context: Dict[str, Any], +) -> Union[ + OkFirstFactorResponse, + OkSecondFactorLinkedResponse, + OkSecondFactorNotLinkedResponse, + LinkingToSessionUserFailedResponse, +]: + log_debug_message("check_auth_type_and_linking_status called") + session_user: Union[AccountLinkingUser, None] = None + if session is None: + log_debug_message( + "check_auth_type_and_linking_status returning first factor because there is no session" + ) + return OkFirstFactorResponse() + else: + if not recipe_init_defined_should_do_automatic_account_linking(): + if MultiFactorAuthRecipe.get_instance() is not None: + raise Exception( + "Please initialise the account linking recipe and define should_do_automatic_account_linking to enable MFA" + ) + else: + return OkFirstFactorResponse() + + if input_user is not None and input_user.id == session.get_user_id(): + log_debug_message( + "check_auth_type_and_linking_status returning secondary factor, session and input user are the same" + ) + return OkSecondFactorLinkedResponse( + session_user=input_user, + ) + + log_debug_message( + f"check_auth_type_and_linking_status loading session user, {input_user.id if input_user else None} === {session.get_user_id()}" + ) + session_user_result = await try_and_make_session_user_into_a_primary_user( + session, skip_session_user_update_in_core, user_context + ) + if session_user_result.status == "SHOULD_AUTOMATICALLY_LINK_FALSE": + return OkFirstFactorResponse() + elif ( + session_user_result.status + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return LinkingToSessionUserFailedResponse( + reason="SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + + session_user = session_user_result.user + + should_link = await AccountLinkingRecipe.get_instance().config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + account_info + ), + session_user, + session, + session.get_tenant_id(), + user_context, + ) + log_debug_message( + f"check_auth_type_and_linking_status session user <-> input user should_do_automatic_account_linking returned {should_link}" + ) + + if isinstance(should_link, ShouldNotAutomaticallyLink): + return OkFirstFactorResponse() + else: + return OkSecondFactorNotLinkedResponse( + session_user=session_user, + linking_to_session_user_requires_verification=should_link.should_require_verification, + ) + + +class OkResponse2: + status: Literal["OK"] + user: AccountLinkingUser + + def __init__(self, user: AccountLinkingUser): + self.status = "OK" + self.user = user + + +async def link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id: str, + input_user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + session: Union[SessionContainer, None], + user_context: Dict[str, Any], +) -> Union[OkResponse2, LinkingToSessionUserFailedResponse,]: + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info called" + ) + + async def retry(): + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info retrying...." + ) + return await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=input_user, + session=session, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + auth_login_method = next( + ( + lm + for lm in input_user.login_methods + if lm.recipe_user_id.get_as_string() == recipe_user_id.get_as_string() + ), + None, + ) + if auth_login_method is None: + raise Exception( + "This should never happen: the recipe_user_id and user is inconsistent in create_primary_user_id_or_link_by_account_info params" + ) + + auth_type_res = await check_auth_type_and_linking_status( + session, + AccountInfoWithRecipeId( + recipe_id=auth_login_method.recipe_id, + email=auth_login_method.email, + phone_number=auth_login_method.phone_number, + third_party=auth_login_method.third_party, + ), + input_user, + False, + user_context, + ) + + if not isinstance( + auth_type_res, + ( + OkFirstFactorResponse, + OkSecondFactorLinkedResponse, + OkSecondFactorNotLinkedResponse, + ), + ): + return LinkingToSessionUserFailedResponse(reason=auth_type_res.reason) + + if isinstance(auth_type_res, OkFirstFactorResponse): + if not recipe_init_defined_should_do_automatic_account_linking(): + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info skipping link by account info because this is a first factor auth and the app hasn't defined should_do_automatic_account_linking" + ) + return OkResponse2(user=input_user) + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info trying to link by account info because this is a first factor auth" + ) + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + input_user=input_user, + session=session, + tenant_id=tenant_id, + user_context=user_context, + ) + if link_res.status == "OK": + assert link_res.user is not None + return OkResponse2(user=link_res.user) + if link_res.status == "NO_LINK": + return OkResponse2(user=input_user) + return await retry() + + if isinstance(auth_type_res, OkSecondFactorLinkedResponse): + return OkResponse2(user=auth_type_res.session_user) + + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info trying to link by session info" + ) + session_linking_res = await try_linking_by_session( + session_user=auth_type_res.session_user, + authenticated_user=input_user, + auth_login_method=auth_login_method, + linking_to_session_user_requires_verification=auth_type_res.linking_to_session_user_requires_verification, + user_context=user_context, + ) + if isinstance(session_linking_res, LinkingToSessionUserFailedResponse): + if session_linking_res.reason == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + return await retry() + else: + return session_linking_res + else: + return session_linking_res + + +class ShouldAutomaticallyLinkFalseResponse: + status: Literal["SHOULD_AUTOMATICALLY_LINK_FALSE"] + + def __init__(self): + self.status = "SHOULD_AUTOMATICALLY_LINK_FALSE" + + +class AccountInfoAlreadyAssociatedResponse: + status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] + + def __init__(self): + self.status = ( + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + + +async def try_and_make_session_user_into_a_primary_user( + session: SessionContainer, + skip_session_user_update_in_core: bool, + user_context: Dict[str, Any], +) -> Union[ + OkResponse2, + ShouldAutomaticallyLinkFalseResponse, + AccountInfoAlreadyAssociatedResponse, +]: + log_debug_message("try_and_make_session_user_into_a_primary_user called") + session_user = await get_user(session.get_user_id(), user_context) + if session_user is None: + raise UnauthorisedError("Session user not found") + + if session_user.is_primary_user: + log_debug_message( + "try_and_make_session_user_into_a_primary_user session user already primary" + ) + return OkResponse2(user=session_user) + else: + log_debug_message( + "try_and_make_session_user_into_a_primary_user not primary user yet" + ) + + account_linking_instance = AccountLinkingRecipe.get_instance() + should_do_account_linking = ( + await account_linking_instance.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + session_user.login_methods[0] + ), + None, + session, + session.get_tenant_id(), + user_context, + ) + ) + log_debug_message( + f"try_and_make_session_user_into_a_primary_user should_do_account_linking: {should_do_account_linking}" + ) + + if isinstance(should_do_account_linking, ShouldAutomaticallyLink): + if skip_session_user_update_in_core: + return OkResponse2(user=session_user) + if ( + should_do_account_linking.should_require_verification + and not session_user.login_methods[0].verified + ): + if ( + await session.get_claim_value(EmailVerificationClaim, user_context) + ) is not False: + log_debug_message( + "try_and_make_session_user_into_a_primary_user updating emailverification status in session" + ) + await session.set_claim_value( + EmailVerificationClaim, False, user_context + ) + log_debug_message( + "try_and_make_session_user_into_a_primary_user throwing validation error" + ) + await session.assert_claims( + [EmailVerificationClaim.validators.is_verified()], user_context + ) + raise Exception( + "This should never happen: email verification claim validator passed after setting value to false" + ) + create_primary_user_res = await account_linking_instance.recipe_implementation.create_primary_user( + recipe_user_id=session_user.login_methods[0].recipe_user_id, + user_context=user_context, + ) + log_debug_message( + f"try_and_make_session_user_into_a_primary_user create_primary_user returned {create_primary_user_res.status}" + ) + if ( + create_primary_user_res.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + raise UnauthorisedError("Session user not found") + elif create_primary_user_res.status == "OK": + return OkResponse2(user=create_primary_user_res.user) + else: + return AccountInfoAlreadyAssociatedResponse() + else: + return ShouldAutomaticallyLinkFalseResponse() + + +async def try_linking_by_session( + linking_to_session_user_requires_verification: bool, + auth_login_method: LoginMethod, + authenticated_user: AccountLinkingUser, + session_user: AccountLinkingUser, + user_context: Dict[str, Any], +) -> Union[OkResponse2, LinkingToSessionUserFailedResponse,]: + log_debug_message("tryLinkingBySession called") + + session_user_has_verified_account_info = any( + ( + lm.has_same_email_as(auth_login_method.email) + or lm.has_same_phone_number_as(auth_login_method.phone_number) + ) + and lm.verified + for lm in session_user.login_methods + ) + + can_link_based_on_verification = ( + not linking_to_session_user_requires_verification + or auth_login_method.verified + or session_user_has_verified_account_info + ) + + if not can_link_based_on_verification: + return LinkingToSessionUserFailedResponse(reason="EMAIL_VERIFICATION_REQUIRED") + + link_accounts_result = ( + await AccountLinkingRecipe.get_instance().recipe_implementation.link_accounts( + recipe_user_id=authenticated_user.login_methods[0].recipe_user_id, + primary_user_id=session_user.id, + user_context=user_context, + ) + ) + + if link_accounts_result.status == "OK": + log_debug_message( + "tryLinkingBySession successfully linked input user to session user" + ) + return OkResponse2(user=link_accounts_result.user) + elif ( + link_accounts_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + "tryLinkingBySession linking to session user failed because of a race condition - input user linked to another user" + ) + return LinkingToSessionUserFailedResponse( + reason="RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + elif link_accounts_result.status == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + log_debug_message( + "tryLinkingBySession linking to session user failed because of a race condition - INPUT_USER_IS_NOT_A_PRIMARY_USER, should retry" + ) + return LinkingToSessionUserFailedResponse( + reason="INPUT_USER_IS_NOT_A_PRIMARY_USER" + ) + else: + log_debug_message( + "tryLinkingBySession linking to session user failed because of a race condition - input user has another primary user it can be linked to" + ) + return LinkingToSessionUserFailedResponse( + reason="ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + + +async def filter_out_invalid_first_factors_or_throw_if_all_are_invalid( + factor_ids: List[str], + tenant_id: str, + has_session: bool, + user_context: Dict[str, Any], +) -> List[str]: + valid_factor_ids: List[str] = [] + for _id in factor_ids: + valid_res = await is_valid_first_factor(tenant_id, _id, user_context) + + if valid_res == "TENANT_NOT_FOUND_ERROR": + if has_session: + raise UnauthorisedError("Tenant not found") + else: + raise Exception("Tenant not found error.") + elif valid_res == "OK": + valid_factor_ids.append(_id) + + if len(valid_factor_ids) == 0: + if not has_session: + raise UnauthorisedError( + "A valid session is required to authenticate with secondary factors" + ) + else: + raise_bad_input_exception( + "First factor sign in/up called for a non-first factor with an active session. This might indicate that you are trying to use this as a secondary factor, but disabled account linking." + ) + + return valid_factor_ids + + +async def filter_out_invalid_second_factors_or_throw_if_all_are_invalid( + factor_ids: List[str], + input_user_already_linked_to_session_user: bool, + session_user: AccountLinkingUser, + session: SessionContainer, + user_context: Dict[str, Any], +) -> List[str]: + log_debug_message( + f"filter_out_invalid_second_factors_or_throw_if_all_are_invalid called for {', '.join(factor_ids)}" + ) + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is not None: + if not input_user_already_linked_to_session_user: + factors_set_up_for_user_prom: Optional[List[str]] = None + mfa_info_prom = None + + async def get_factors_set_up_for_user(): + nonlocal factors_set_up_for_user_prom + if factors_set_up_for_user_prom is None: + factors_set_up_for_user_prom = await mfa_instance.recipe_implementation.get_factors_setup_for_user( + user=session_user, user_context=user_context + ) + return factors_set_up_for_user_prom + + async def get_mfa_requirements_for_auth(): + nonlocal mfa_info_prom + if mfa_info_prom is None: + mfa_info_prom = await update_and_get_mfa_related_info_in_session( + input_session=session, user_context=user_context + ) + return mfa_info_prom.mfa_requirements_for_auth + + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid checking if linking is allowed by the mfa recipe" + ) + caught_setup_factor_error: Optional[Exception] = None + valid_factor_ids: List[str] = [] + + for _id in factor_ids: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid checking assert_allowed_to_setup_factor_else_throw_invalid_claim_error" + ) + try: + await mfa_instance.recipe_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + factor_id=_id, + session=session, + factors_set_up_for_user=get_factors_set_up_for_user, + mfa_requirements_for_auth=get_mfa_requirements_for_auth, + user_context=user_context, + ) + log_debug_message( + f"filter_out_invalid_second_factors_or_throw_if_all_are_invalid {id} valid because assert_allowed_to_setup_factor_else_throw_invalid_claim_error passed" + ) + valid_factor_ids.append(_id) + except Exception as err: + log_debug_message( + f"filter_out_invalid_second_factors_or_throw_if_all_are_invalid assert_allowed_to_setup_factor_else_throw_invalid_claim_error failed for {id}" + ) + caught_setup_factor_error = err + + if len(valid_factor_ids) == 0: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid rethrowing error from assert_allowed_to_setup_factor_else_throw_invalid_claim_error because we found no valid factors" + ) + if caught_setup_factor_error is not None: + raise caught_setup_factor_error + else: + raise Exception("Should never come here") + + return valid_factor_ids + else: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid allowing all factors because it'll not create new link" + ) + return factor_ids + else: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid allowing all factors because MFA is not enabled" + ) + return factor_ids diff --git a/supertokens_python/recipe/multitenancy/asyncio/__init__.py b/supertokens_python/recipe/multitenancy/asyncio/__init__.py index cfc495c25..53051de60 100644 --- a/supertokens_python/recipe/multitenancy/asyncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/asyncio/__init__.py @@ -14,6 +14,8 @@ from typing import Any, Dict, Union, Optional, TYPE_CHECKING +from supertokens_python.types import RecipeUserId + from ..interfaces import ( TenantConfig, CreateOrUpdateTenantOkResult, @@ -112,7 +114,7 @@ async def delete_third_party_config( async def associate_user_to_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ) -> Union[ AssociateUserToTenantOkResult, @@ -127,13 +129,13 @@ async def associate_user_to_tenant( recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.associate_user_to_tenant( - tenant_id, user_id, user_context + tenant_id, recipe_user_id, user_context ) async def dissociate_user_from_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ) -> DisassociateUserFromTenantOkResult: if user_context is None: @@ -142,5 +144,5 @@ async def dissociate_user_from_tenant( recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.dissociate_user_from_tenant( - tenant_id, user_id, user_context + tenant_id, recipe_user_id, user_context ) diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 0cc0f223d..4789a2c9e 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Union, Callable, Awaitable, Optional, List -from supertokens_python.types import APIResponse, GeneralErrorResponse +from supertokens_python.types import APIResponse, GeneralErrorResponse, RecipeUserId if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse @@ -171,7 +171,7 @@ async def delete_third_party_config( async def associate_user_to_tenant( self, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> Union[ AssociateUserToTenantOkResult, @@ -186,7 +186,7 @@ async def associate_user_to_tenant( async def dissociate_user_from_tenant( self, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> DisassociateUserFromTenantOkResult: pass diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 8919cecd5..8313c6491 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -22,6 +22,7 @@ AssociateUserToTenantThirdPartyUserAlreadyExistsError, DisassociateUserFromTenantOkResult, ) +from supertokens_python.types import RecipeUserId from .interfaces import ( RecipeInterface, @@ -242,7 +243,10 @@ async def delete_third_party_config( ) async def associate_user_to_tenant( - self, tenant_id: Optional[str], user_id: str, user_context: Dict[str, Any] + self, + tenant_id: Optional[str], + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], ) -> Union[ AssociateUserToTenantOkResult, AssociateUserToTenantUnknownUserIdError, @@ -255,7 +259,7 @@ async def associate_user_to_tenant( f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/user" ), { - "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), }, user_context=user_context, ) @@ -282,14 +286,17 @@ async def associate_user_to_tenant( raise Exception("Should never come here") async def dissociate_user_from_tenant( - self, tenant_id: Optional[str], user_id: str, user_context: Dict[str, Any] + self, + tenant_id: Optional[str], + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], ) -> DisassociateUserFromTenantOkResult: response = await self.querier.send_post_request( NormalisedURLPath( f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/user/remove" ), { - "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), }, user_context=user_context, ) diff --git a/supertokens_python/recipe/multitenancy/syncio/__init__.py b/supertokens_python/recipe/multitenancy/syncio/__init__.py index 27025f476..e152c880d 100644 --- a/supertokens_python/recipe/multitenancy/syncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/syncio/__init__.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.types import RecipeUserId if TYPE_CHECKING: from ..interfaces import TenantConfig, ProviderConfig @@ -95,7 +96,7 @@ def delete_third_party_config( def associate_user_to_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ): if user_context is None: @@ -103,12 +104,12 @@ def associate_user_to_tenant( from supertokens_python.recipe.multitenancy.asyncio import associate_user_to_tenant - return sync(associate_user_to_tenant(tenant_id, user_id, user_context)) + return sync(associate_user_to_tenant(tenant_id, recipe_user_id, user_context)) def dissociate_user_from_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ): if user_context is None: @@ -118,4 +119,4 @@ def dissociate_user_from_tenant( dissociate_user_from_tenant, ) - return sync(dissociate_user_from_tenant(tenant_id, user_id, user_context)) + return sync(dissociate_user_from_tenant(tenant_id, recipe_user_id, user_context)) diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index bd158dccb..fae46a384 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -52,6 +52,7 @@ def init( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, + overwrite_session_during_sign_in_up: Union[bool, None] = None, ) -> Callable[[AppInfo], RecipeModule]: return SessionRecipe.init( cookie_domain, @@ -67,4 +68,5 @@ def init( use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, + overwrite_session_during_sign_in_up, ) diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 7e4eb4799..f2e221d89 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -93,6 +93,7 @@ def __init__( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, + overwrite_session_during_sign_in_up: Union[bool, None] = None, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( @@ -110,6 +111,7 @@ def __init__( use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, + overwrite_session_during_sign_in_up, ) self.openid_recipe = OpenIdRecipe( recipe_id, @@ -310,6 +312,7 @@ def init( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, + overwrite_session_during_sign_in_up: Union[bool, None] = None, ): def func(app_info: AppInfo): if SessionRecipe.__instance is None: @@ -329,6 +332,7 @@ def func(app_info: AppInfo): use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, + overwrite_session_during_sign_in_up, ) return SessionRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 96e5c43a4..13f3d8dca 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -391,6 +391,7 @@ def __init__( use_dynamic_access_token_signing_key: bool, expose_access_token_to_frontend_in_cookie_based_auth: bool, jwks_refresh_interval_sec: int, + overwrite_session_during_sign_in_up: bool, ): self.session_expired_status_code = session_expired_status_code self.invalid_claim_status_code = invalid_claim_status_code @@ -411,6 +412,7 @@ def __init__( self.framework = framework self.mode = mode self.jwks_refresh_interval_sec = jwks_refresh_interval_sec + self.overwrite_session_during_sign_in_up = overwrite_session_during_sign_in_up def validate_and_normalise_user_input( @@ -434,6 +436,7 @@ def validate_and_normalise_user_input( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, + overwrite_session_during_sign_in_up: Union[bool, None] = None, ): _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. if anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: @@ -561,6 +564,11 @@ def anti_csrf_function( use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, + ( + overwrite_session_during_sign_in_up + if overwrite_session_during_sign_in_up is not None + else False + ), ) diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 7b49113d2..627737649 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -19,6 +19,7 @@ from supertokens_python import init from supertokens_python.framework.fastapi import get_middleware from supertokens_python.recipe import emailpassword, multitenancy, session +from supertokens_python.types import RecipeUserId from tests.utils import ( setup_function, teardown_function, @@ -299,17 +300,17 @@ async def test_user_association_and_disassociation_with_tenants(): assert isinstance(signup_response, SignUpOkResult) user_id = signup_response.user.user_id - await associate_user_to_tenant("t1", user_id) - await associate_user_to_tenant("t2", user_id) - await associate_user_to_tenant("t3", user_id) + await associate_user_to_tenant("t1", RecipeUserId(user_id)) + await associate_user_to_tenant("t2", RecipeUserId(user_id)) + await associate_user_to_tenant("t3", RecipeUserId(user_id)) user = await get_user_by_id(user_id) assert user is not None assert len(user.tenant_ids) == 4 # public + 3 tenants - await dissociate_user_from_tenant("t1", user_id) - await dissociate_user_from_tenant("t2", user_id) - await dissociate_user_from_tenant("t3", user_id) + await dissociate_user_from_tenant("t1", RecipeUserId(user_id)) + await dissociate_user_from_tenant("t2", RecipeUserId(user_id)) + await dissociate_user_from_tenant("t3", RecipeUserId(user_id)) user = await get_user_by_id(user_id) assert user is not None diff --git a/tests/userroles/test_multitenancy.py b/tests/userroles/test_multitenancy.py index ddf02c6c1..30e557102 100644 --- a/tests/userroles/test_multitenancy.py +++ b/tests/userroles/test_multitenancy.py @@ -26,6 +26,7 @@ add_role_to_user, get_roles_for_user, ) +from supertokens_python.types import RecipeUserId from tests.utils import get_st_init_args from tests.utils import ( @@ -64,9 +65,9 @@ async def test_multitenancy_in_user_roles(): assert isinstance(user, SignUpOkResult) user_id = user.user.user_id - await associate_user_to_tenant("t1", user_id) - await associate_user_to_tenant("t2", user_id) - await associate_user_to_tenant("t3", user_id) + await associate_user_to_tenant("t1", RecipeUserId(user_id)) + await associate_user_to_tenant("t2", RecipeUserId(user_id)) + await associate_user_to_tenant("t3", RecipeUserId(user_id)) await create_new_role_or_add_permissions("role1", []) await create_new_role_or_add_permissions("role2", []) From a281813b13925574f2194364fc5fe08f57f52e48 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 28 Aug 2024 22:24:45 +0530 Subject: [PATCH 026/126] emailpassword recipe translation - wip --- supertokens_python/auth_utils.py | 11 +- .../emailpassword/api/implementation.py | 674 +++++++++++++++--- .../emailpassword/api/password_reset.py | 18 + .../recipe/emailpassword/api/signin.py | 28 +- .../recipe/emailpassword/api/signup.py | 25 +- .../recipe/emailpassword/asyncio/__init__.py | 245 ++++--- .../backward_compatibility/__init__.py | 17 +- .../recipe/emailpassword/interfaces.py | 245 ++++--- .../recipe/emailpassword/recipe.py | 171 ++++- .../emailpassword/recipe_implementation.py | 325 ++++++--- .../recipe/emailpassword/syncio/__init__.py | 208 ++++-- .../recipe/emailpassword/types.py | 33 +- 12 files changed, 1476 insertions(+), 524 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 8e4ce9015..d49357b0f 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -40,6 +40,9 @@ from supertokens_python.utils import log_debug_message from .asyncio import get_user +from typing import Dict, Union +from supertokens_python.utils import log_debug_message + class OkResponse: status: Literal["OK"] @@ -563,7 +566,7 @@ class OkResponse2: def __init__(self, user: AccountLinkingUser): self.status = "OK" - self.user = user + self.user: AccountLinkingUser = user async def link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( @@ -950,3 +953,9 @@ async def get_mfa_requirements_for_auth(): "filter_out_invalid_second_factors_or_throw_if_all_are_invalid allowing all factors because MFA is not enabled" ) return factor_ids + + +def is_fake_email(email: str) -> bool: + return email.endswith("@stfakeemail.supertokens.com") or email.endswith( + ".fakeemail.com" + ) # .fakeemail.com for older users diff --git a/supertokens_python/recipe/emailpassword/api/implementation.py b/supertokens_python/recipe/emailpassword/api/implementation.py index 0e0953db3..5498079c7 100644 --- a/supertokens_python/recipe/emailpassword/api/implementation.py +++ b/supertokens_python/recipe/emailpassword/api/implementation.py @@ -13,41 +13,62 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import ( + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + get_authenticating_user_and_add_to_current_tenant_if_required, + is_fake_email, + post_auth_checks, + pre_auth_checks, +) from supertokens_python.logger import log_debug_message +from supertokens_python.recipe.accountlinking import ( + AccountInfoWithRecipeIdAndUserId, + ShouldNotAutomaticallyLink, +) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId from supertokens_python.recipe.emailpassword.constants import ( FORM_FIELD_EMAIL_ID, FORM_FIELD_PASSWORD_ID, ) from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface, - CreateResetPasswordWrongUserIdError, + CreateResetPasswordOkResult, + EmailAlreadyExistsError, EmailExistsGetOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, GeneratePasswordResetTokenPostOkResult, - PasswordResetPostInvalidTokenResponse, + LinkingToSessionUserFailedError, + PasswordPolicyViolationError, PasswordResetPostOkResult, - ResetPasswordUsingTokenInvalidTokenError, + PasswordResetTokenInvalidError, + SignInOkResult, + SignInPostNotAllowedResponse, SignInPostOkResult, - SignInPostWrongCredentialsError, - SignInWrongCredentialsError, - SignUpEmailAlreadyExistsError, - SignUpPostEmailAlreadyExistsError, + SignUpOkResult, + SignUpPostNotAllowedResponse, SignUpPostOkResult, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + WrongCredentialsError, ) from supertokens_python.recipe.emailpassword.types import ( + EmailTemplateVars, FormField, - PasswordResetEmailTemplateVars, PasswordResetEmailTemplateVarsUser, ) +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.totp.types import UnknownUserIdError from ..utils import get_password_reset_link -from supertokens_python.recipe.session.asyncio import create_new_session -from supertokens_python.utils import find_first_occurrence_in_list if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import APIOptions -from supertokens_python.types import GeneralErrorResponse, RecipeUserId +from supertokens_python.types import AccountInfo, GeneralErrorResponse, RecipeUserId class APIImplementation(APIInterface): @@ -58,10 +79,23 @@ async def email_exists_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[EmailExistsGetOkResult, GeneralErrorResponse]: - user = await api_options.recipe_implementation.get_user_by_email( - email, tenant_id, user_context + # Check if there exists an email password user with the same email + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=email), + do_union_of_account_info=False, + user_context=user_context, + ) + + email_password_user_exists = any( + any( + lm.recipe_id == "emailpassword" and lm.has_same_email_as(email) + for lm in user.login_methods + ) + for user in users ) - return EmailExistsGetOkResult(user is not None) + + return EmailExistsGetOkResult(exists=email_password_user_exists) async def generate_password_reset_token_post( self, @@ -69,52 +103,178 @@ async def generate_password_reset_token_post( tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[GeneratePasswordResetTokenPostOkResult, GeneralErrorResponse]: - emailFormField = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_EMAIL_ID, form_fields + ) -> Union[ + GeneratePasswordResetTokenPostOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, + GeneralErrorResponse, + ]: + email = next(f.value for f in form_fields if f.id == "email") + + async def generate_and_send_password_reset_token( + primary_user_id: str, recipe_user_id: Optional[RecipeUserId] + ) -> Union[ + GeneratePasswordResetTokenPostOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, + GeneralErrorResponse, + ]: + user_id = ( + recipe_user_id.get_as_string() if recipe_user_id else primary_user_id + ) + response = ( + await api_options.recipe_implementation.create_reset_password_token( + tenant_id=tenant_id, + user_id=user_id, + email=email, + user_context=user_context, + ) + ) + if isinstance(response, UnknownUserIdError): + log_debug_message( + f"Password reset email not sent, unknown user id: {user_id}" + ) + return GeneratePasswordResetTokenPostOkResult() + + assert isinstance(response, CreateResetPasswordOkResult) + password_reset_link = get_password_reset_link( + app_info=api_options.app_info, + token=response.token, + tenant_id=tenant_id, + request=api_options.request, + user_context=user_context, + ) + + log_debug_message(f"Sending password reset email to {email}") + await api_options.email_delivery.ingredient_interface_impl.send_email( + EmailTemplateVars( + user=PasswordResetEmailTemplateVarsUser( + user_id=primary_user_id, + recipe_user_id=recipe_user_id, + email=email, + ), + password_reset_link=password_reset_link, + tenant_id=tenant_id, + ), + user_context=user_context, + ) + + return GeneratePasswordResetTokenPostOkResult() + + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=email), + do_union_of_account_info=False, + user_context=user_context, ) - if emailFormField is None: - raise Exception("Should never come here") - email = emailFormField.value - user = await api_options.recipe_implementation.get_user_by_email( - email, tenant_id, user_context + email_password_account = next( + ( + lm + for user in users + for lm in user.login_methods + if lm.recipe_id == "emailpassword" and lm.has_same_email_as(email) + ), + None, ) - if user is None: - return GeneratePasswordResetTokenPostOkResult() + primary_user_associated_with_email = next( + (u for u in users if u.is_primary_user), None + ) - token_result = ( - await api_options.recipe_implementation.create_reset_password_token( - user.user_id, tenant_id, user_context + if primary_user_associated_with_email is None: + if email_password_account is None: + log_debug_message( + f"Password reset email not sent, unknown user email: {email}" + ) + return GeneratePasswordResetTokenPostOkResult() + return await generate_and_send_password_reset_token( + email_password_account.recipe_user_id.get_as_string(), + email_password_account.recipe_user_id, ) + + email_verified = any( + lm.has_same_email_as(email) and lm.verified + for lm in primary_user_associated_with_email.login_methods + ) + + has_other_email_or_phone = any( + (lm.email is not None and not lm.has_same_email_as(email)) + or lm.phone_number is not None + for lm in primary_user_associated_with_email.login_methods ) - if isinstance(token_result, CreateResetPasswordWrongUserIdError): - log_debug_message( - "Password reset email not sent, unknown user id: %s", user.user_id + if not email_verified and has_other_email_or_phone: + return GeneratePasswordResetTokenPostNotAllowedResponse( + "Reset password link was not created because of account take over risk. Please contact support. (ERR_CODE_001)" ) - return GeneratePasswordResetTokenPostOkResult() - password_reset_link = get_password_reset_link( - api_options.app_info, - token_result.token, + should_do_account_linking_response = await AccountLinkingRecipe.get_instance().config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + email_password_account + or AccountInfoWithRecipeId(email=email, recipe_id="emailpassword") + ), + primary_user_associated_with_email, + None, tenant_id, - api_options.request, user_context, ) - log_debug_message("Sending password reset email to %s", email) - send_email_input = PasswordResetEmailTemplateVars( - user=PasswordResetEmailTemplateVarsUser(user.user_id, user.email), - password_reset_link=password_reset_link, - tenant_id=tenant_id, - ) - await api_options.email_delivery.ingredient_interface_impl.send_email( - send_email_input, user_context + if email_password_account is None: + if isinstance( + should_do_account_linking_response, ShouldNotAutomaticallyLink + ): + log_debug_message( + "Password reset email not sent, since email password user didn't exist, and account linking not enabled" + ) + return GeneratePasswordResetTokenPostOkResult() + + is_sign_up_allowed = ( + await AccountLinkingRecipe.get_instance().is_sign_up_allowed( + new_user=AccountInfoWithRecipeId( + email=email, recipe_id="emailpassword" + ), + is_verified=True, + session=None, + tenant_id=tenant_id, + user_context=user_context, + ) + ) + if is_sign_up_allowed: + return await generate_and_send_password_reset_token( + primary_user_associated_with_email.id, None + ) + else: + log_debug_message( + f"Password reset email not sent, is_sign_up_allowed returned false for email: {email}" + ) + return GeneratePasswordResetTokenPostOkResult() + + are_the_two_accounts_linked = any( + lm.recipe_user_id.get_as_string() + == email_password_account.recipe_user_id.get_as_string() + for lm in primary_user_associated_with_email.login_methods ) - return GeneratePasswordResetTokenPostOkResult() + if are_the_two_accounts_linked: + return await generate_and_send_password_reset_token( + primary_user_associated_with_email.id, + email_password_account.recipe_user_id, + ) + + if isinstance(should_do_account_linking_response, ShouldNotAutomaticallyLink): + return await generate_and_send_password_reset_token( + email_password_account.recipe_user_id.get_as_string(), + email_password_account.recipe_user_id, + ) + + if not should_do_account_linking_response.should_require_verification: + return await generate_and_send_password_reset_token( + primary_user_associated_with_email.id, + email_password_account.recipe_user_id, + ) + + return await generate_and_send_password_reset_token( + primary_user_associated_with_email.id, email_password_account.recipe_user_id + ) async def password_reset_post( self, @@ -125,103 +285,413 @@ async def password_reset_post( user_context: Dict[str, Any], ) -> Union[ PasswordResetPostOkResult, - PasswordResetPostInvalidTokenResponse, + PasswordResetTokenInvalidError, + PasswordPolicyViolationError, GeneralErrorResponse, ]: - new_password_for_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields - ) - if new_password_for_field is None: - raise Exception("Should never come here") - new_password = new_password_for_field.value + async def mark_email_as_verified(recipe_user_id: RecipeUserId, email: str): + email_verification_instance = ( + EmailVerificationRecipe.get_instance_optional() + ) + if email_verification_instance: + token_response = await email_verification_instance.recipe_implementation.create_email_verification_token( + tenant_id=tenant_id, + recipe_user_id=recipe_user_id, + email=email, + user_context=user_context, + ) + + if token_response.status == "OK": + await email_verification_instance.recipe_implementation.verify_email_using_token( + tenant_id=tenant_id, + token=token_response.token, + attempt_account_linking=False, + user_context=user_context, + ) + + async def do_update_password_and_verify_email_and_try_link_if_not_primary( + recipe_user_id: RecipeUserId, + ): + update_response = ( + await api_options.recipe_implementation.update_email_or_password( + tenant_id_for_password_policy=tenant_id, + email=None, + recipe_user_id=recipe_user_id, + password=new_password, + apply_password_policy=None, + user_context=user_context, + ) + ) + + if isinstance(update_response, EmailAlreadyExistsError) or isinstance( + update_response, UpdateEmailOrPasswordEmailChangeNotAllowedError + ): + raise Exception("Should never happen") + if isinstance(update_response, UnknownUserIdError): + return PasswordResetTokenInvalidError() + elif isinstance(update_response, PasswordPolicyViolationError): + return update_response + else: + await mark_email_as_verified( + recipe_user_id, email_for_whom_token_was_generated + ) + updated_user_after_email_verification = await get_user( + recipe_user_id.get_as_string(), user_context + ) + if updated_user_after_email_verification is None: + raise Exception( + "Should never happen - user deleted after during password reset" + ) + + if updated_user_after_email_verification.is_primary_user: + return PasswordResetPostOkResult( + user=updated_user_after_email_verification, + email=email_for_whom_token_was_generated, + ) + + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=updated_user_after_email_verification, + session=None, + user_context=user_context, + ) + user_after_we_tried_linking = ( + link_res.user + if link_res.status == "OK" + else updated_user_after_email_verification + ) + + assert user_after_we_tried_linking is not None - result = await api_options.recipe_implementation.reset_password_using_token( - token, new_password, tenant_id, user_context + return PasswordResetPostOkResult( + user=user_after_we_tried_linking, + email=email_for_whom_token_was_generated, + ) + + new_password = next(f.value for f in form_fields if f.id == "password") + + token_consumption_response = ( + await api_options.recipe_implementation.consume_password_reset_token( + token=token, + tenant_id=tenant_id, + user_context=user_context, + ) ) - if isinstance(result, ResetPasswordUsingTokenInvalidTokenError): - return PasswordResetPostInvalidTokenResponse() + if isinstance(token_consumption_response, PasswordResetTokenInvalidError): + return PasswordResetTokenInvalidError() + + user_id_for_whom_token_was_generated = token_consumption_response.user_id + email_for_whom_token_was_generated = token_consumption_response.email - return PasswordResetPostOkResult(result.user_id) + existing_user = await get_user(token_consumption_response.user_id, user_context) + + if existing_user is None: + return PasswordResetTokenInvalidError() + + if existing_user.is_primary_user: + email_password_user_is_linked_to_existing_user = any( + lm.recipe_user_id.get_as_string() + == user_id_for_whom_token_was_generated + and lm.recipe_id == "emailpassword" + for lm in existing_user.login_methods + ) + + if email_password_user_is_linked_to_existing_user: + return await do_update_password_and_verify_email_and_try_link_if_not_primary( + RecipeUserId(user_id_for_whom_token_was_generated) + ) + else: + create_user_response = ( + await api_options.recipe_implementation.create_new_recipe_user( + tenant_id=tenant_id, + email=token_consumption_response.email, + password=new_password, + user_context=user_context, + ) + ) + if isinstance(create_user_response, EmailAlreadyExistsError): + return PasswordResetTokenInvalidError() + else: + await mark_email_as_verified( + create_user_response.user.login_methods[0].recipe_user_id, + token_consumption_response.email, + ) + updated_user = await get_user( + create_user_response.user.id, + user_context, + ) + if updated_user is None: + raise Exception( + "Should never happen - user deleted after during password reset" + ) + create_user_response.user = updated_user + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=create_user_response.user, + session=None, + user_context=user_context, + ) + user_after_linking = ( + link_res.user + if link_res.status == "OK" + else create_user_response.user + ) + assert user_after_linking is not None + return PasswordResetPostOkResult( + user=user_after_linking, + email=token_consumption_response.email, + ) + else: + return ( + await do_update_password_and_verify_email_and_try_link_if_not_primary( + RecipeUserId(user_id_for_whom_token_was_generated) + ) + ) async def sign_in_post( self, form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignInPostOkResult, SignInPostWrongCredentialsError, GeneralErrorResponse + SignInPostOkResult, + WrongCredentialsError, + SignInPostNotAllowedResponse, + GeneralErrorResponse, ]: - password_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields + error_code_map = { + "SIGN_IN_NOT_ALLOWED": "Cannot sign in due to security reasons. Please try resetting your password, use a different login method or contact support. (ERR_CODE_008)", + "LINKING_TO_SESSION_USER_FAILED": { + "EMAIL_VERIFICATION_REQUIRED": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_009)", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_010)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_011)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_012)", + }, + } + + email = next(f.value for f in form_fields if f.id == FORM_FIELD_EMAIL_ID) + password = next(f.value for f in form_fields if f.id == FORM_FIELD_PASSWORD_ID) + + recipe_id = "emailpassword" + + async def check_credentials_on_tenant(tenant_id: str) -> bool: + verify_result = await api_options.recipe_implementation.verify_credentials( + email=email, + password=password, + tenant_id=tenant_id, + user_context=user_context, + ) + return isinstance(verify_result, SignInOkResult) + + if is_fake_email(email) and session is None: + return WrongCredentialsError() + + authenticating_user = ( + await get_authenticating_user_and_add_to_current_tenant_if_required( + email=email, + phone_number=None, + third_party=None, + user_context=user_context, + recipe_id=recipe_id, + session=session, + tenant_id=tenant_id, + check_credentials_on_tenant=check_credentials_on_tenant, + ) ) - if password_form_field is None: - raise Exception("Should never come here") - password = password_form_field.value - email_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_EMAIL_ID, form_fields + is_verified = ( + authenticating_user is not None + and authenticating_user.login_method is not None + and authenticating_user.login_method.verified ) - if email_form_field is None: - raise Exception("Should never come here") - email = email_form_field.value - result = await api_options.recipe_implementation.sign_in( - email, password, tenant_id, user_context + if authenticating_user is None: + return WrongCredentialsError() + + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id=recipe_id, + email=email, + ), + factor_ids=["emailpassword"], + is_sign_up=False, + authenticating_user=authenticating_user.user, + is_verified=is_verified, + sign_in_verifies_login_method=False, + skip_session_user_update_in_core=False, + tenant_id=tenant_id, + user_context=user_context, + session=session, ) - if isinstance(result, SignInWrongCredentialsError): - return SignInPostWrongCredentialsError() + if pre_auth_checks_result.status != "OK": + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + raise Exception("Should never happen") + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInPostNotAllowedResponse(reason=reason) + + if is_fake_email(email) and pre_auth_checks_result.is_first_factor: + return WrongCredentialsError() - user = result.user - session = await create_new_session( + sign_in_response = await api_options.recipe_implementation.sign_in( + email=email, + password=password, + session=session, tenant_id=tenant_id, - request=api_options.request, - recipe_user_id=RecipeUserId(user.user_id), - access_token_payload={}, - session_data_in_database={}, user_context=user_context, ) - return SignInPostOkResult(user, session) + + if isinstance(sign_in_response, WrongCredentialsError): + return WrongCredentialsError() + if isinstance(sign_in_response, LinkingToSessionUserFailedError): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[sign_in_response.reason] + return SignInPostNotAllowedResponse(reason=reason) + + post_auth_checks_result = await post_auth_checks( + authenticated_user=sign_in_response.user, + recipe_user_id=sign_in_response.recipe_user_id, + is_sign_up=False, + factor_id="emailpassword", + session=session, + tenant_id=tenant_id, + user_context=user_context, + request=api_options.request, + ) + + if post_auth_checks_result.status != "OK": + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInPostNotAllowedResponse(reason) + + return SignInPostOkResult( + user=post_auth_checks_result.user, + session=post_auth_checks_result.session, + ) async def sign_up_post( self, form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignUpPostOkResult, SignUpPostEmailAlreadyExistsError, GeneralErrorResponse + SignUpPostOkResult, + EmailAlreadyExistsError, + SignUpPostNotAllowedResponse, + GeneralErrorResponse, ]: - password_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields - ) - if password_form_field is None: - raise Exception("Should never come here") - password = password_form_field.value + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign up due to security reasons. Please try logging in, use a different login method or contact support. (ERR_CODE_007)", + "LINKING_TO_SESSION_USER_FAILED": { + "EMAIL_VERIFICATION_REQUIRED": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_013)", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_014)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_015)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_016)", + }, + } - email_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_EMAIL_ID, form_fields - ) - if email_form_field is None: - raise Exception("Should never come here") - email = email_form_field.value + email = next(f.value for f in form_fields if f.id == "email") + password = next(f.value for f in form_fields if f.id == "password") - result = await api_options.recipe_implementation.sign_up( - email, password, tenant_id, user_context + pre_auth_check_res = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id="emailpassword", + email=email, + ), + factor_ids=["emailpassword"], + is_sign_up=True, + is_verified=is_fake_email(email), + sign_in_verifies_login_method=False, + skip_session_user_update_in_core=False, + authenticating_user=None, # since this is a sign up, this is None + tenant_id=tenant_id, + user_context=user_context, + session=session, ) - if isinstance(result, SignUpEmailAlreadyExistsError): - return SignUpPostEmailAlreadyExistsError() + if pre_auth_check_res.status == "SIGN_UP_NOT_ALLOWED": + conflicting_users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo( + email=email, + ), + do_union_of_account_info=False, + user_context=user_context, + ) + if any( + any( + lm.recipe_id == "emailpassword" and lm.has_same_email_as(email) + for lm in u.login_methods + ) + for u in conflicting_users + ): + return EmailAlreadyExistsError() + + if pre_auth_check_res.status != "OK": + if isinstance(pre_auth_check_res, SignInNotAllowedResponse): + raise Exception("Should never happen") + if isinstance(pre_auth_check_res, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignUpPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_check_res.reason] + return SignUpPostNotAllowedResponse(reason=reason) - user = result.user - session = await create_new_session( + if is_fake_email(email) and pre_auth_check_res.is_first_factor: + # Fake emails cannot be used as a first factor + return EmailAlreadyExistsError() + + sign_up_response = await api_options.recipe_implementation.sign_up( tenant_id=tenant_id, + email=email, + password=password, + session=session, + user_context=user_context, + ) + + if isinstance(sign_up_response, EmailAlreadyExistsError): + return sign_up_response + if not isinstance(sign_up_response, SignUpOkResult): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[sign_up_response.reason] + return SignUpPostNotAllowedResponse(reason=reason) + + post_auth_checks_res = await post_auth_checks( + authenticated_user=sign_up_response.user, + recipe_user_id=sign_up_response.recipe_user_id, + is_sign_up=True, + factor_id="emailpassword", + session=session, request=api_options.request, - recipe_user_id=RecipeUserId(user.user_id), - access_token_payload={}, - session_data_in_database={}, + tenant_id=tenant_id, user_context=user_context, ) - return SignUpPostOkResult(user, session) + + if post_auth_checks_res.status != "OK": + # this will fail cause error_code_map doesn't have SIGN_IN_NOT_ALLOWED + # but that's ok, cause it should never come here for sign up anyway. + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignUpPostNotAllowedResponse(reason) + + return SignUpPostOkResult( + user=post_auth_checks_res.user, + session=post_auth_checks_res.session, + ) diff --git a/supertokens_python/recipe/emailpassword/api/password_reset.py b/supertokens_python/recipe/emailpassword/api/password_reset.py index 8ae7cbc7a..af72d7a4b 100644 --- a/supertokens_python/recipe/emailpassword/api/password_reset.py +++ b/supertokens_python/recipe/emailpassword/api/password_reset.py @@ -14,6 +14,14 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.recipe.emailpassword.exceptions import ( + raise_form_field_exception, +) + +from supertokens_python.recipe.emailpassword.interfaces import ( + PasswordPolicyViolationError, +) +from supertokens_python.recipe.emailpassword.types import ErrorFormField if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import ( @@ -55,4 +63,14 @@ async def handle_password_reset_api( response = await api_implementation.password_reset_post( form_fields, token, tenant_id, api_options, user_context ) + if isinstance(response, PasswordPolicyViolationError): + return raise_form_field_exception( + "Error in input formFields", + [ + ErrorFormField( + id="password", + error=response.failure_reason, + ) + ], + ) return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailpassword/api/signin.py b/supertokens_python/recipe/emailpassword/api/signin.py index 6dfd770f8..2fbfa2271 100644 --- a/supertokens_python/recipe/emailpassword/api/signin.py +++ b/supertokens_python/recipe/emailpassword/api/signin.py @@ -14,6 +14,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.recipe.emailpassword.interfaces import SignInPostOkResult + +from supertokens_python.recipe.session.asyncio import get_session if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import ( @@ -22,7 +25,10 @@ ) from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + send_200_response, +) from .utils import validate_form_fields_or_throw_error @@ -43,8 +49,26 @@ async def handle_sign_in_api( api_options.config.sign_in_feature.form_fields, form_fields_raw, tenant_id ) + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + response = await api_implementation.sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) + if isinstance(response, SignInPostOkResult): + return send_200_response( + get_backwards_compatible_user_info( + req=api_options.request, + user_info=response.user, + session_container=response.session, + created_new_recipe_user=None, + user_context=user_context, + ), + api_options.response, + ) + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailpassword/api/signup.py b/supertokens_python/recipe/emailpassword/api/signup.py index e8a1ce1f9..f7ad8ec26 100644 --- a/supertokens_python/recipe/emailpassword/api/signup.py +++ b/supertokens_python/recipe/emailpassword/api/signup.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict from supertokens_python.recipe.emailpassword.interfaces import SignUpPostOkResult +from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.types import GeneralErrorResponse from ..exceptions import raise_form_field_exception @@ -28,7 +29,10 @@ ) from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + send_200_response, +) from .utils import validate_form_fields_or_throw_error @@ -49,12 +53,27 @@ async def handle_sign_up_api( api_options.config.sign_up_feature.form_fields, form_fields_raw, tenant_id ) + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + response = await api_implementation.sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) if isinstance(response, SignUpPostOkResult): - return send_200_response(response.to_json(), api_options.response) + return send_200_response( + get_backwards_compatible_user_info( + req=api_options.request, + user_info=response.user, + session_container=response.session, + created_new_recipe_user=None, + user_context=user_context, + ), + api_options.response, + ) if isinstance(response, GeneralErrorResponse): return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailpassword/asyncio/__init__.py b/supertokens_python/recipe/emailpassword/asyncio/__init__.py index a97f47367..ec01bcf0a 100644 --- a/supertokens_python/recipe/emailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/asyncio/__init__.py @@ -12,19 +12,30 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict, Union, Optional +from typing_extensions import Literal from supertokens_python import get_request_from_user_context +from supertokens_python.asyncio import get_user from supertokens_python.recipe.emailpassword import EmailPasswordRecipe +from supertokens_python.recipe.session import SessionContainer -from ..types import EmailTemplateVars, User +from ..types import EmailTemplateVars from ...multitenancy.constants import DEFAULT_TENANT_ID +from ....types import RecipeUserId from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordWrongUserIdError, - CreateResetPasswordLinkUnknownUserIdError, - CreateResetPasswordLinkOkResult, - SendResetPasswordEmailOkResult, - SendResetPasswordEmailUnknownUserIdError, + CreateResetPasswordOkResult, + ConsumePasswordResetTokenOkResult, + PasswordResetTokenInvalidError, + UpdateEmailOrPasswordOkResult, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, + SignUpOkResult, + EmailAlreadyExistsError, + LinkingToSessionUserFailedError, + SignInOkResult, + WrongCredentialsError, ) from supertokens_python.recipe.emailpassword.utils import get_password_reset_link from supertokens_python.recipe.emailpassword.types import ( @@ -33,55 +44,57 @@ ) -async def update_email_or_password( - user_id: str, - email: Union[str, None] = None, - password: Union[str, None] = None, - apply_password_policy: Union[bool, None] = None, - tenant_id_for_password_policy: Optional[str] = None, - user_context: Union[None, Dict[str, Any]] = None, -): +async def sign_up( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.update_email_or_password( - user_id, - email, - password, - apply_password_policy, - tenant_id_for_password_policy or DEFAULT_TENANT_ID, - user_context, + return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_up( + email, password, tenant_id or DEFAULT_TENANT_ID, session, user_context ) -async def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[None, User]: +async def sign_in( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: if user_context is None: user_context = {} - return ( - await EmailPasswordRecipe.get_instance().recipe_implementation.get_user_by_id( - user_id, user_context - ) + return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_in( + email, password, tenant_id or DEFAULT_TENANT_ID, session, user_context ) -async def get_user_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: +async def verify_credentials( + tenant_id: str, + email: str, + password: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.get_user_by_email( - email, tenant_id, user_context + return await EmailPasswordRecipe.get_instance().recipe_implementation.verify_credentials( + email, password, tenant_id or DEFAULT_TENANT_ID, user_context ) async def create_reset_password_token( - tenant_id: str, user_id: str, user_context: Union[None, Dict[str, Any]] = None -): + tenant_id: str, + user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: if user_context is None: user_context = {} return await EmailPasswordRecipe.get_instance().recipe_implementation.create_reset_password_token( - user_id, tenant_id, user_context + user_id, email, tenant_id or DEFAULT_TENANT_ID, user_context ) @@ -89,91 +102,147 @@ async def reset_password_using_token( tenant_id: str, token: str, new_password: str, - user_context: Union[None, Dict[str, Any]] = None, -): + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + PasswordPolicyViolationError, + PasswordResetTokenInvalidError, + UnknownUserIdError, +]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.reset_password_using_token( - token, new_password, tenant_id, user_context + consume_resp = await consume_password_reset_token(tenant_id, token, user_context) + if not isinstance(consume_resp, ConsumePasswordResetTokenOkResult): + return consume_resp + + result = await update_email_or_password( + recipe_user_id=RecipeUserId(consume_resp.user_id), + email=consume_resp.email, + password=new_password, + tenant_id_for_password_policy=tenant_id, + user_context=user_context, ) + if isinstance(result, EmailAlreadyExistsError) or isinstance( + result, UpdateEmailOrPasswordEmailChangeNotAllowedError + ): + raise Exception("Should never happen") -async def sign_in( - tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_in( - email, password, tenant_id, user_context - ) + return result -async def sign_up( +async def consume_password_reset_token( tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -): + token: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_up( - email, password, tenant_id, user_context + return await EmailPasswordRecipe.get_instance().recipe_implementation.consume_password_reset_token( + token, tenant_id or DEFAULT_TENANT_ID, user_context ) -async def send_email( - input_: EmailTemplateVars, - user_context: Union[None, Dict[str, Any]] = None, -): +async def update_email_or_password( + recipe_user_id: RecipeUserId, + email: Optional[str] = None, + password: Optional[str] = None, + apply_password_policy: Optional[bool] = None, + tenant_id_for_password_policy: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, +]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().email_delivery.ingredient_interface_impl.send_email( - input_, user_context + return await EmailPasswordRecipe.get_instance().recipe_implementation.update_email_or_password( + recipe_user_id, + email, + password, + apply_password_policy, + tenant_id_for_password_policy or DEFAULT_TENANT_ID, + user_context, ) async def create_reset_password_link( - tenant_id: str, user_id: str, user_context: Optional[Dict[str, Any]] = None -): + tenant_id: str, + user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[str, UnknownUserIdError]: if user_context is None: user_context = {} - token = await create_reset_password_token(tenant_id, user_id, user_context) - if isinstance(token, CreateResetPasswordWrongUserIdError): - return CreateResetPasswordLinkUnknownUserIdError() + token = await create_reset_password_token(tenant_id, user_id, email, user_context) + if isinstance(token, UnknownUserIdError): + return token recipe_instance = EmailPasswordRecipe.get_instance() request = get_request_from_user_context(user_context) - return CreateResetPasswordLinkOkResult( - link=get_password_reset_link( - recipe_instance.get_app_info(), - token.token, - tenant_id, - request, - user_context, - ) + return get_password_reset_link( + recipe_instance.get_app_info(), + token.token, + tenant_id or DEFAULT_TENANT_ID, + request, + user_context, ) async def send_reset_password_email( - tenant_id: str, user_id: str, user_context: Optional[Dict[str, Any]] = None -): - link = await create_reset_password_link(tenant_id, user_id, user_context) - if isinstance(link, CreateResetPasswordLinkUnknownUserIdError): - return SendResetPasswordEmailUnknownUserIdError() + tenant_id: str, + user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[Literal["UNKNOWN_USER_ID_ERROR"], Literal["OK"]]: + if user_context is None: + user_context = {} + + user = await get_user(user_id, user_context) + if user is None: + return "UNKNOWN_USER_ID_ERROR" + + login_method = next( + ( + m + for m in user.login_methods + if m.recipe_id == "emailpassword" and m.has_same_email_as(email) + ), + None, + ) + if login_method is None: + return "UNKNOWN_USER_ID_ERROR" - user = await get_user_by_id(user_id, user_context) - assert user is not None + link = await create_reset_password_link(tenant_id, user_id, email, user_context) + if isinstance(link, UnknownUserIdError): + return "UNKNOWN_USER_ID_ERROR" + assert login_method.email is not None await send_email( PasswordResetEmailTemplateVars( - PasswordResetEmailTemplateVarsUser(user.user_id, user.email), - link.link, - tenant_id, + user=PasswordResetEmailTemplateVarsUser( + user_id=user.id, + email=login_method.email, + recipe_user_id=login_method.recipe_user_id, + ), + password_reset_link=link, + tenant_id=tenant_id or DEFAULT_TENANT_ID, ), user_context, ) - return SendResetPasswordEmailOkResult() + return "OK" + + +async def send_email( + input_: EmailTemplateVars, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return await EmailPasswordRecipe.get_instance().email_delivery.ingredient_interface_impl.send_email( + input_, user_context + ) diff --git a/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py b/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py index ee295789e..ab256d780 100644 --- a/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py +++ b/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py @@ -24,13 +24,17 @@ EmailTemplateVars, RecipeInterface, ) -from supertokens_python.recipe.emailpassword.types import User +from supertokens_python.recipe.emailpassword.types import ( + PasswordResetEmailTemplateVarsUser, +) from supertokens_python.supertokens import AppInfo from supertokens_python.utils import handle_httpx_client_exceptions async def create_and_send_email_using_supertokens_service( - app_info: AppInfo, user: User, password_reset_url_with_token: str + app_info: AppInfo, + user: PasswordResetEmailTemplateVarsUser, + password_reset_url_with_token: str, ) -> None: if ("SUPERTOKENS_ENV" in environ) and (environ["SUPERTOKENS_ENV"] == "testing"): return @@ -66,19 +70,12 @@ async def send_email( template_vars: EmailTemplateVars, user_context: Dict[str, Any], ) -> None: - user = await self.recipe_interface_impl.get_user_by_id( - user_id=template_vars.user.id, user_context=user_context - ) - if user is None: - raise Exception("Should never come here") - # we add this here cause the user may have overridden the sendEmail function # to change the input email and if we don't do this, the input email # will get reset by the getUserById call above. - user.email = template_vars.user.email try: await create_and_send_email_using_supertokens_service( - self.app_info, user, template_vars.password_reset_link + self.app_info, template_vars.user, template_vars.password_reset_link ) except Exception: pass diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 3567754fc..549666f57 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -15,143 +15,179 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing_extensions import Literal from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.emailpassword.types import EmailTemplateVars from ...supertokens import AppInfo - -from ...types import APIResponse, GeneralErrorResponse +from ...types import APIResponse, GeneralErrorResponse, RecipeUserId if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer - from .types import FormField, User + from .types import FormField + from ...types import AccountLinkingUser from .utils import EmailPasswordConfig class SignUpOkResult: - def __init__(self, user: User): - self.user = user - - -class SignUpEmailAlreadyExistsError: - pass - + status: str = "OK" -class SignInOkResult: - def __init__(self, user: User): + def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): self.user = user + self.recipe_user_id = recipe_user_id -class SignInWrongCredentialsError: - pass - - -class CreateResetPasswordOkResult: - def __init__(self, token: str): - self.token = token +class EmailAlreadyExistsError(APIResponse): + status: str = "EMAIL_ALREADY_EXISTS_ERROR" + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} -class CreateResetPasswordWrongUserIdError: - pass +class LinkingToSessionUserFailedError: + def __init__( + self, + reason: Literal[ + "EMAIL_VERIFICATION_REQUIRED", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "INPUT_USER_IS_NOT_A_PRIMARY_USER", + ], + ): + self.reason = reason -class CreateResetPasswordLinkOkResult: - def __init__(self, link: str): - self.link = link +class SignInOkResult: + def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): + self.user = user + self.recipe_user_id = recipe_user_id -class CreateResetPasswordLinkUnknownUserIdError: - pass +class WrongCredentialsError(APIResponse): + status: str = "WRONG_CREDENTIALS_ERROR" -class SendResetPasswordEmailOkResult: - pass + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} -class SendResetPasswordEmailUnknownUserIdError: - pass +class CreateResetPasswordOkResult: + def __init__(self, token: str): + self.token = token -class ResetPasswordUsingTokenOkResult: - def __init__(self, user_id: Union[str, None]): +class ConsumePasswordResetTokenOkResult: + def __init__(self, email: str, user_id: str): + self.email = email self.user_id = user_id -class ResetPasswordUsingTokenInvalidTokenError: - pass +class PasswordResetTokenInvalidError(APIResponse): + status: str = "RESET_PASSWORD_INVALID_TOKEN_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} class UpdateEmailOrPasswordOkResult: pass -class UpdateEmailOrPasswordEmailAlreadyExistsError: +class UnknownUserIdError: pass -class UpdateEmailOrPasswordUnknownUserIdError: - pass - +class UpdateEmailOrPasswordEmailChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason -class UpdateEmailOrPasswordPasswordPolicyViolationError: - failure_reason: str +class PasswordPolicyViolationError(APIResponse): def __init__(self, failure_reason: str): self.failure_reason = failure_reason + def to_json(self) -> Dict[str, Any]: + return { + "status": "PASSWORD_POLICY_VIOLATED_ERROR", + "failureReason": self.failure_reason, + } + class RecipeInterface(ABC): def __init__(self): pass @abstractmethod - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: + async def sign_up( + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + user_context: Dict[str, Any], + ) -> Union[ + SignUpOkResult, + EmailAlreadyExistsError, + LinkingToSessionUserFailedError, + ]: pass @abstractmethod - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: + async def create_new_recipe_user( + self, + email: str, + password: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[SignUpOkResult, EmailAlreadyExistsError]: pass @abstractmethod - async def create_reset_password_token( - self, user_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[CreateResetPasswordOkResult, CreateResetPasswordWrongUserIdError]: + async def sign_in( + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + user_context: Dict[str, Any], + ) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError,]: pass @abstractmethod - async def reset_password_using_token( + async def verify_credentials( self, - token: str, - new_password: str, + email: str, + password: str, tenant_id: str, user_context: Dict[str, Any], - ) -> Union[ - ResetPasswordUsingTokenOkResult, ResetPasswordUsingTokenInvalidTokenError - ]: + ) -> Union[SignInOkResult, WrongCredentialsError]: pass @abstractmethod - async def sign_in( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignInOkResult, SignInWrongCredentialsError]: + async def create_reset_password_token( + self, + user_id: str, + email: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: pass @abstractmethod - async def sign_up( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignUpOkResult, SignUpEmailAlreadyExistsError]: + async def consume_password_reset_token( + self, + token: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: pass @abstractmethod async def update_email_or_password( self, - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None], password: Union[str, None], apply_password_policy: Union[bool, None], @@ -159,9 +195,10 @@ async def update_email_or_password( user_context: Dict[str, Any], ) -> Union[ UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, - UpdateEmailOrPasswordUnknownUserIdError, - UpdateEmailOrPasswordPasswordPolicyViolationError, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, ]: pass @@ -203,71 +240,70 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class PasswordResetPostOkResult(APIResponse): - status: str = "OK" +class GeneratePasswordResetTokenPostNotAllowedResponse(APIResponse): + status: str = "PASSWORD_RESET_NOT_ALLOWED" - def __init__(self, user_id: Union[str, None]): - self.user_id = user_id + def __init__(self, reason: str): + self.reason = reason def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "reason": self.reason} -class PasswordResetPostInvalidTokenResponse(APIResponse): - status: str = "RESET_PASSWORD_INVALID_TOKEN_ERROR" +class PasswordResetPostOkResult(APIResponse): + status: str = "OK" + + def __init__(self, email: str, user: AccountLinkingUser): + self.email = email + self.user = user def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "email": self.email, "user": self.user.to_json()} class SignInPostOkResult(APIResponse): status: str = "OK" - def __init__(self, user: User, session: SessionContainer): + def __init__(self, user: AccountLinkingUser, session: SessionContainer): self.user = user self.session = session def to_json(self) -> Dict[str, Any]: return { "status": self.status, - "user": { - "id": self.user.user_id, - "email": self.user.email, - "timeJoined": self.user.time_joined, - }, + "user": self.user.to_json(), } -class SignInPostWrongCredentialsError(APIResponse): - status: str = "WRONG_CREDENTIALS_ERROR" +class SignInPostNotAllowedResponse(APIResponse): + status: str = "SIGN_IN_NOT_ALLOWED" + + def __init__(self, reason: str): + self.reason = reason def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "reason": self.reason} class SignUpPostOkResult(APIResponse): status: str = "OK" - def __init__(self, user: User, session: SessionContainer): + def __init__(self, user: AccountLinkingUser, session: SessionContainer): self.user = user self.session = session def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "user": { - "id": self.user.user_id, - "email": self.user.email, - "timeJoined": self.user.time_joined, - }, - } + return {"status": self.status, "user": self.user.to_json()} -class SignUpPostEmailAlreadyExistsError(APIResponse): - status: str = "EMAIL_ALREADY_EXISTS_ERROR" +class SignUpPostNotAllowedResponse(APIResponse): + status: str = "SIGN_UP_NOT_ALLOWED" + + def __init__(self, reason: str): + self.reason = reason def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "reason": self.reason} class APIInterface: @@ -295,7 +331,11 @@ async def generate_password_reset_token_post( tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[GeneratePasswordResetTokenPostOkResult, GeneralErrorResponse]: + ) -> Union[ + GeneratePasswordResetTokenPostOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, + GeneralErrorResponse, + ]: pass @abstractmethod @@ -308,7 +348,8 @@ async def password_reset_post( user_context: Dict[str, Any], ) -> Union[ PasswordResetPostOkResult, - PasswordResetPostInvalidTokenResponse, + PasswordResetTokenInvalidError, + PasswordPolicyViolationError, GeneralErrorResponse, ]: pass @@ -318,10 +359,14 @@ async def sign_in_post( self, form_fields: List[FormField], tenant_id: str, + session: Union[SessionContainer, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignInPostOkResult, SignInPostWrongCredentialsError, GeneralErrorResponse + SignInPostOkResult, + WrongCredentialsError, + SignInPostNotAllowedResponse, + GeneralErrorResponse, ]: pass @@ -330,9 +375,13 @@ async def sign_up_post( self, form_fields: List[FormField], tenant_id: str, + session: Union[SessionContainer, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignUpPostOkResult, SignUpPostEmailAlreadyExistsError, GeneralErrorResponse + SignUpPostOkResult, + EmailAlreadyExistsError, + SignUpPostNotAllowedResponse, + GeneralErrorResponse, ]: pass diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index abe38a7c6..f6a622f76 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -15,6 +15,7 @@ from os import environ from typing import TYPE_CHECKING, Any, Dict, List, Union +from supertokens_python.auth_utils import is_fake_email from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -23,12 +24,19 @@ EmailPasswordIngredients, EmailTemplateVars, ) -from supertokens_python.recipe_module import APIHandled, RecipeModule -from ..emailverification.interfaces import ( - UnknownUserIdError, - GetEmailForUserIdOkResult, - EmailDoesNotExistError, +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import ( + FactorIds, + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetEmailsForFactorFromOtherRecipesFunc, + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + GetFactorsSetupForUserFromOtherRecipesFunc, ) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.types import AccountLinkingUser, RecipeUserId from .api.implementation import APIImplementation from .exceptions import FieldError, SuperTokensEmailPasswordError @@ -117,7 +125,145 @@ def get_emailpassword_config() -> EmailPasswordConfig: ) def callback(): - pass + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is not None: + + async def f1(_: TenantConfig): + return ["emailpassword"] + + mfa_instance.add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc(f1) + ) + + async def get_factors_setup_for_user( + user: AccountLinkingUser, _: Dict[str, Any] + ) -> List[str]: + for login_method in user.login_methods: + # We don't check for tenant_id here because if we find the user + # with emailpassword login_method from different tenant, then + # we assume the factor is setup for this user. And as part of factor + # completion, we associate that login_method with the session's tenant_id + if login_method.recipe_id == EmailPasswordRecipe.recipe_id: + return ["emailpassword"] + return [] + + mfa_instance.add_func_to_get_factors_setup_for_user_from_other_recipes( + GetFactorsSetupForUserFromOtherRecipesFunc( + get_factors_setup_for_user + ) + ) + + async def get_emails_for_factor( + user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ]: + # This function is called in the MFA info endpoint API. + # Based on https://github.com/supertokens/supertokens-node/pull/741#discussion_r1432749346 + + # preparing some reusable variables for the logic below... + session_login_method = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ), + None, + ) + if session_login_method is None: + return GetEmailsForFactorUnknownSessionRecipeUserIdResult() + + # We order the login methods based on time_joined (oldest first) + ordered_login_methods = sorted( + user.login_methods, key=lambda lm: lm.time_joined + ) + # Then we take the ones that belong to this recipe + recipe_login_methods = [ + lm + for lm in ordered_login_methods + if lm.recipe_id == EmailPasswordRecipe.recipe_id + ] + + if recipe_login_methods: + # If there are login methods belonging to this recipe, the factor is set up + # In this case we only list email addresses that have a password associated with them + result = ( + # First we take the verified real emails associated with emailpassword login methods ordered by time_joined (oldest first) + [ + lm.email + for lm in recipe_login_methods + if lm.email + and not is_fake_email(lm.email) + and lm.verified + ] + + + # Then we take the non-verified real emails associated with emailpassword login methods ordered by time_joined (oldest first) + [ + lm.email + for lm in recipe_login_methods + if lm.email + and not is_fake_email(lm.email) + and not lm.verified + ] + + + # Lastly, fake emails associated with emailpassword login methods ordered by time_joined (oldest first) + [ + lm.email + for lm in recipe_login_methods + if lm.email and is_fake_email(lm.email) + ] + ) + else: + # This factor hasn't been set up, we list all emails belonging to the user + if any( + lm.email and not is_fake_email(lm.email) + for lm in ordered_login_methods + ): + # If there is at least one real email address linked to the user, we only suggest real addresses + result = [ + lm.email + for lm in ordered_login_methods + if lm.email and not is_fake_email(lm.email) + ] + else: + # Else we use the fake ones + result = [ + lm.email + for lm in ordered_login_methods + if lm.email and is_fake_email(lm.email) + ] + + # Since in this case emails are not guaranteed to be unique, we de-duplicate the results, keeping the oldest one in the list. + result = list(dict.fromkeys(result)) + + # If the login_method associated with the session has an email address, we move it to the top of the list (if it's already in the list) + if ( + session_login_method.email + and session_login_method.email in result + ): + result.remove(session_login_method.email) + result.insert(0, session_login_method.email) + + # If the list is empty we generate an email address to make the flow where the user is never asked for + # an email address easier to implement. + if not result: + result.append( + f"{session_recipe_user_id}@stfakeemail.supertokens.com" + ) + + return GetEmailsForFactorOkResult( + factor_id_to_emails_map={"emailpassword": result} + ) + + mfa_instance.add_func_to_get_emails_for_factor_from_other_recipes( + GetEmailsForFactorFromOtherRecipesFunc(get_emails_for_factor) + ) + + mt_recipe = MultitenancyRecipe.get_instance_optional() + if mt_recipe is not None: + mt_recipe.all_available_first_factors.append(FactorIds.EMAILPASSWORD) PostSTInitCallbacks.add_post_init_callback(callback) @@ -266,16 +412,3 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") EmailPasswordRecipe.__instance = None - - # instance functions below............... - - async def get_email_for_user_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[UnknownUserIdError, GetEmailForUserIdOkResult, EmailDoesNotExistError]: - user_info = await self.recipe_implementation.get_user_by_id( - user_id, user_context - ) - if user_info is not None: - return GetEmailForUserIdOkResult(user_info.email) - - return UnknownUserIdError() diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 3e59fd309..9d093f06d 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -14,30 +14,39 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, Union, Callable +from supertokens_python.asyncio import get_user from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import RecipeUserId from .interfaces import ( CreateResetPasswordOkResult, - CreateResetPasswordWrongUserIdError, + UnknownUserIdError, RecipeInterface, - ResetPasswordUsingTokenOkResult, - ResetPasswordUsingTokenInvalidTokenError, + ConsumePasswordResetTokenOkResult, + PasswordResetTokenInvalidError, SignInOkResult, - SignInWrongCredentialsError, - SignUpEmailAlreadyExistsError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + WrongCredentialsError, + EmailAlreadyExistsError, SignUpOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordUnknownUserIdError, - UpdateEmailOrPasswordPasswordPolicyViolationError, + PasswordPolicyViolationError, + LinkingToSessionUserFailedError, ) -from .types import User from .utils import EmailPasswordConfig from .constants import FORM_FIELD_PASSWORD_ID +from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedResponse, + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, +) if TYPE_CHECKING: from supertokens_python.querier import Querier + from ...types import AccountLinkingUser class RecipeImplementation(RecipeInterface): @@ -50,116 +59,177 @@ def __init__( self.querier = querier self.get_emailpassword_config = get_emailpassword_config - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - params = {"userId": user_id} - response = await self.querier.send_get_request( - NormalisedURLPath("/recipe/user"), params, user_context - ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) - return None - - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - params = {"email": email} - response = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), params, user_context + async def sign_up( + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + user_context: Dict[str, Any], + ) -> Union[ + SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError + ]: + response = await self.create_new_recipe_user( + email=email, + password=password, + tenant_id=tenant_id, + user_context=user_context, ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) - return None + if isinstance(response, EmailAlreadyExistsError): + return response - async def create_reset_password_token( - self, user_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[CreateResetPasswordOkResult, CreateResetPasswordWrongUserIdError]: - data = {"userId": user_id} - response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset/token"), - data, + updated_user = response.user + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=response.user, + recipe_user_id=response.recipe_user_id, + session=session, user_context=user_context, ) - if "status" in response and response["status"] == "OK": - return CreateResetPasswordOkResult(response["token"]) - return CreateResetPasswordWrongUserIdError() - async def reset_password_using_token( + if isinstance(link_result, LinkingToSessionUserFailedResponse): + return LinkingToSessionUserFailedError(reason=link_result.reason) + + updated_user = link_result.user + + return SignUpOkResult( + user=updated_user, + recipe_user_id=response.recipe_user_id, + ) + + async def create_new_recipe_user( self, - token: str, - new_password: str, + email: str, + password: str, tenant_id: str, user_context: Dict[str, Any], - ) -> Union[ - ResetPasswordUsingTokenOkResult, ResetPasswordUsingTokenInvalidTokenError - ]: - data = {"method": "token", "token": token, "newPassword": new_password} + ) -> Union[SignUpOkResult, EmailAlreadyExistsError]: response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset"), - data, + NormalisedURLPath(f"{tenant_id}/recipe/signup"), + { + "email": email, + "password": password, + }, user_context=user_context, ) - if "status" not in response or response["status"] != "OK": - return ResetPasswordUsingTokenInvalidTokenError() - user_id = None - if "userId" in response: - user_id = response["userId"] - return ResetPasswordUsingTokenOkResult(user_id) + if response["status"] == "OK": + return SignUpOkResult( + user=AccountLinkingUser.from_json(response["user"]), + recipe_user_id=RecipeUserId(response["recipeUserId"]), + ) + return EmailAlreadyExistsError() async def sign_in( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignInOkResult, SignInWrongCredentialsError]: - data = {"password": password, "email": email} + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + user_context: Dict[str, Any], + ) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: + response = await self.verify_credentials( + email, password, tenant_id, user_context + ) + + if isinstance(response, SignInOkResult): + login_method = next( + ( + lm + for lm in response.user.login_methods + if lm.recipe_user_id.get_as_string() + == response.recipe_user_id.get_as_string() + ), + None, + ) + + assert login_method is not None + + if not login_method.verified: + await AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=response.user, + recipe_user_id=response.recipe_user_id, + user_context=user_context, + ) + + # We do this to get the updated user (in case the above function updated the verification status) + updated_user = await get_user( + response.recipe_user_id.get_as_string(), user_context + ) + assert updated_user is not None + response.user = updated_user + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=response.user, + recipe_user_id=response.recipe_user_id, + session=session, + user_context=user_context, + ) + + if isinstance(link_result, LinkingToSessionUserFailedResponse): + return LinkingToSessionUserFailedError(reason=link_result.reason) + + response.user = link_result.user + + return response + + async def verify_credentials( + self, + email: str, + password: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[SignInOkResult, WrongCredentialsError]: response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/signin"), - data, + { + "email": email, + "password": password, + }, user_context=user_context, ) - if "status" in response and response["status"] == "OK": + + if response["status"] == "OK": return SignInOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) + user=AccountLinkingUser.from_json(response["user"]), + recipe_user_id=RecipeUserId(response["recipeUserId"]), ) - return SignInWrongCredentialsError() - async def sign_up( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignUpOkResult, SignUpEmailAlreadyExistsError]: - data = {"password": password, "email": email} + return WrongCredentialsError() + + async def create_reset_password_token( + self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + ) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: + data = {"userId": user_id, "email": email} response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signup"), + NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset/token"), data, user_context=user_context, ) if "status" in response and response["status"] == "OK": - return SignUpOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) - ) - return SignUpEmailAlreadyExistsError() + return CreateResetPasswordOkResult(response["token"]) + return UnknownUserIdError() + + async def consume_password_reset_token( + self, + token: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: + data = {"token": token} + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset/token/consume"), + data, + user_context=user_context, + ) + if "status" not in response or response["status"] != "OK": + return PasswordResetTokenInvalidError() + return ConsumePasswordResetTokenOkResult(response["email"], response["userId"]) async def update_email_or_password( self, - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None], password: Union[str, None], apply_password_policy: Union[bool, None], @@ -167,34 +237,79 @@ async def update_email_or_password( user_context: Dict[str, Any], ) -> Union[ UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, - UpdateEmailOrPasswordUnknownUserIdError, - UpdateEmailOrPasswordPasswordPolicyViolationError, + EmailAlreadyExistsError, + UnknownUserIdError, + PasswordPolicyViolationError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, ]: - data = {"userId": user_id} + account_linking = AccountLinkingRecipe.get_instance() + data = {"recipeUserId": recipe_user_id.get_as_string()} + if email is not None: - data = {"email": email, **data} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UnknownUserIdError() + + ev_instance = EmailVerificationRecipe.get_instance_optional() + is_email_verified = False + if ev_instance: + is_email_verified = ( + await ev_instance.recipe_implementation.is_email_verified( + recipe_user_id=recipe_user_id, + email=email, + user_context=user_context, + ) + ) + + is_email_change_allowed = await account_linking.is_email_change_allowed( + user=user, + is_verified=is_email_verified, + new_email=email, + session=None, + user_context=user_context, + ) + if not is_email_change_allowed.allowed: + reason = ( + "New email cannot be applied to existing account because of account takeover risks." + if is_email_change_allowed.reason == "ACCOUNT_TAKEOVER_RISK" + else "New email cannot be applied to existing account because of there is another primary user with the same email address." + ) + return UpdateEmailOrPasswordEmailChangeNotAllowedError(reason) + + data["email"] = email + if password is not None: if apply_password_policy is None or apply_password_policy: form_fields = ( self.get_emailpassword_config().sign_up_feature.form_fields ) - password_field = list( - filter(lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields) - )[0] + password_field = next( + field for field in form_fields if field.id == FORM_FIELD_PASSWORD_ID + ) error = await password_field.validate( password, tenant_id_for_password_policy ) if error is not None: - return UpdateEmailOrPasswordPasswordPolicyViolationError(error) - data = {"password": password, **data} + return PasswordPolicyViolationError(error) + data["password"] = password + response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, user_context=user_context, ) - if "status" in response and response["status"] == "OK": + + if response.get("status") == "OK": + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UnknownUserIdError() + await AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) return UpdateEmailOrPasswordOkResult() - if "status" in response and response["status"] == "EMAIL_ALREADY_EXISTS_ERROR": - return UpdateEmailOrPasswordEmailAlreadyExistsError() - return UpdateEmailOrPasswordUnknownUserIdError() + elif response.get("status") == "EMAIL_ALREADY_EXISTS_ERROR": + return EmailAlreadyExistsError() + else: + return UnknownUserIdError() diff --git a/supertokens_python/recipe/emailpassword/syncio/__init__.py b/supertokens_python/recipe/emailpassword/syncio/__init__.py index 8a7cd833f..1681589cb 100644 --- a/supertokens_python/recipe/emailpassword/syncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/syncio/__init__.py @@ -12,130 +12,200 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict, Union, Optional +from typing_extensions import Literal from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.emailpassword.interfaces import ( + SignUpOkResult, + EmailAlreadyExistsError, + LinkingToSessionUserFailedError, + SignInOkResult, + WrongCredentialsError, + CreateResetPasswordOkResult, + ConsumePasswordResetTokenOkResult, + PasswordResetTokenInvalidError, + UpdateEmailOrPasswordOkResult, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, +) +from supertokens_python.recipe.emailpassword.types import ( + EmailTemplateVars, +) +from supertokens_python.types import RecipeUserId -from ..interfaces import SignInOkResult, SignInWrongCredentialsError -from ..types import EmailTemplateVars, User +def sign_up( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import sign_up as async_sign_up -def update_email_or_password( - user_id: str, - email: Union[str, None] = None, - password: Union[str, None] = None, - apply_password_policy: Union[bool, None] = None, - tenant_id_for_password_policy: Optional[str] = None, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password - - return sync( - update_email_or_password( - user_id, - email, - password, - apply_password_policy, - tenant_id_for_password_policy, - user_context, - ) - ) + return sync(async_sign_up(tenant_id, email, password, session, user_context)) -def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[None, User]: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id +def sign_in( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import sign_in as async_sign_in - return sync(get_user_by_id(user_id, user_context)) + return sync(async_sign_in(tenant_id, email, password, session, user_context)) -def get_user_by_email( +def verify_credentials( tenant_id: str, email: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> Union[None, User]: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email + password: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + verify_credentials as async_verify_credentials, + ) - return sync(get_user_by_email(tenant_id, email, user_context)) + return sync(async_verify_credentials(tenant_id, email, password, user_context)) def create_reset_password_token( tenant_id: str, user_id: str, - user_context: Union[None, Dict[str, Any]] = None, -): + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_token, + create_reset_password_token as async_create_reset_password_token, ) - return sync(create_reset_password_token(tenant_id, user_id, user_context)) + return sync( + async_create_reset_password_token(tenant_id, user_id, email, user_context) + ) def reset_password_using_token( tenant_id: str, token: str, new_password: str, - user_context: Union[None, Dict[str, Any]] = None, -): + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + PasswordPolicyViolationError, + PasswordResetTokenInvalidError, + UnknownUserIdError, +]: + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - reset_password_using_token, + reset_password_using_token as async_reset_password_using_token, ) return sync( - reset_password_using_token(tenant_id, token, new_password, user_context) + async_reset_password_using_token(tenant_id, token, new_password, user_context) ) -def sign_in( - tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> Union[SignInOkResult, SignInWrongCredentialsError]: - from supertokens_python.recipe.emailpassword.asyncio import sign_in - - return sync(sign_in(tenant_id, email, password, user_context)) - - -def sign_up( +def consume_password_reset_token( tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.emailpassword.asyncio import sign_up + token: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + consume_password_reset_token as async_consume_password_reset_token, + ) - return sync(sign_up(tenant_id, email, password, user_context)) + return sync(async_consume_password_reset_token(tenant_id, token, user_context)) -def send_email( - input_: EmailTemplateVars, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.emailpassword.asyncio import send_email +def update_email_or_password( + recipe_user_id: RecipeUserId, + email: Optional[str] = None, + password: Optional[str] = None, + apply_password_policy: Optional[bool] = None, + tenant_id_for_password_policy: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, +]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + update_email_or_password as async_update_email_or_password, + ) - return sync(send_email(input_, user_context)) + return sync( + async_update_email_or_password( + recipe_user_id, + email, + password, + apply_password_policy, + tenant_id_for_password_policy, + user_context, + ) + ) def create_reset_password_link( tenant_id: str, user_id: str, + email: str, user_context: Optional[Dict[str, Any]] = None, -): +) -> Union[str, UnknownUserIdError]: + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_link, + create_reset_password_link as async_create_reset_password_link, ) - return sync(create_reset_password_link(tenant_id, user_id, user_context)) + return sync( + async_create_reset_password_link(tenant_id, user_id, email, user_context) + ) def send_reset_password_email( tenant_id: str, user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[Literal["UNKNOWN_USER_ID_ERROR"], Literal["OK"]]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + send_reset_password_email as async_send_reset_password_email, + ) + + return sync( + async_send_reset_password_email(tenant_id, user_id, email, user_context) + ) + + +def send_email( + input_: EmailTemplateVars, user_context: Optional[Dict[str, Any]] = None, ): + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - send_reset_password_email, + send_email as async_send_email, ) - return sync(send_reset_password_email(tenant_id, user_id, user_context)) + return sync(async_send_email(input_, user_context)) diff --git a/supertokens_python/recipe/emailpassword/types.py b/supertokens_python/recipe/emailpassword/types.py index 41752c2cd..cc8c61b62 100644 --- a/supertokens_python/recipe/emailpassword/types.py +++ b/supertokens_python/recipe/emailpassword/types.py @@ -12,38 +12,14 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations -from typing import Awaitable, Callable, List, TypeVar, Union +from typing import Awaitable, Callable, Optional, TypeVar, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( EmailDeliveryInterface, SMTPServiceInterface, ) - - -class User: - def __init__( - self, user_id: str, email: str, time_joined: int, tenant_ids: List[str] - ): - self.user_id = user_id - self.email = email - self.time_joined = time_joined - self.tenant_ids = tenant_ids - - def __eq__(self, other: object): - return ( - isinstance(other, self.__class__) - and self.user_id == other.user_id - and self.email == other.email - and self.time_joined == other.time_joined - and self.tenant_ids == other.tenant_ids - ) - - -class UsersResponse: - def __init__(self, users: List[User], next_pagination_token: Union[str, None]): - self.users = users - self.next_pagination_token = next_pagination_token +from supertokens_python.types import RecipeUserId class ErrorFormField: @@ -89,8 +65,11 @@ def __init__( class PasswordResetEmailTemplateVarsUser: - def __init__(self, user_id: str, email: str): + def __init__( + self, user_id: str, recipe_user_id: Optional[RecipeUserId], email: str + ): self.id = user_id + self.recipe_user_id = recipe_user_id self.email = email From 13fa57c42e44294df49a2df932207eae893ef1ee Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 29 Aug 2024 13:46:00 +0530 Subject: [PATCH 027/126] changes ep recipe --- supertokens_python/asyncio/__init__.py | 18 ++++- supertokens_python/auth_utils.py | 3 - .../api/userdetails/user_password_put.py | 45 +++--------- .../dashboard/api/userdetails/user_put.py | 8 +-- supertokens_python/recipe/dashboard/utils.py | 2 - .../emailpassword/api/implementation.py | 8 ++- .../recipe/emailpassword/asyncio/__init__.py | 5 +- .../emailpassword/recipe_implementation.py | 2 +- supertokens_python/syncio/__init__.py | 19 ++++- tests/auth-react/django3x/mysite/utils.py | 6 +- tests/auth-react/django3x/polls/views.py | 19 ++--- tests/auth-react/fastapi-server/app.py | 16 +++-- tests/auth-react/flask-server/app.py | 24 ++++--- tests/emailpassword/test_emaildelivery.py | 14 ++-- tests/emailpassword/test_multitenancy.py | 52 ++++++++------ tests/emailpassword/test_passwordreset.py | 13 ++-- .../test_updateemailorpassword.py | 5 +- tests/multitenancy/test_tenants_crud.py | 9 +-- .../test_supertokens_functions.py | 2 +- tests/test-server/emailpassword.py | 24 ++++--- tests/test_querier.py | 72 +++++++++---------- tests/test_user_context.py | 57 +++++++++++---- tests/useridmapping/create_user_id_mapping.py | 8 +-- tests/useridmapping/delete_user_id_mapping.py | 4 +- tests/useridmapping/get_user_id_mapping.py | 2 +- tests/useridmapping/recipe_tests.py | 32 ++++----- tests/userroles/test_multitenancy.py | 2 +- 27 files changed, 265 insertions(+), 206 deletions(-) diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index e0284ab54..f92abece2 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -26,7 +26,7 @@ ) from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult -from supertokens_python.types import AccountLinkingUser +from supertokens_python.types import AccountInfo, AccountLinkingUser async def get_users_oldest_first( @@ -155,3 +155,19 @@ async def update_or_delete_user_id_mapping_info( return await Supertokens.get_instance().update_or_delete_user_id_mapping_info( user_id, user_id_type, external_user_id_info, user_context ) + + +async def list_users_by_account_info( + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool = False, + user_context: Optional[Dict[str, Any]] = None, +) -> List[AccountLinkingUser]: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id, + account_info, + do_union_of_account_info, + user_context, + ) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index d49357b0f..40a875540 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -40,9 +40,6 @@ from supertokens_python.utils import log_debug_message from .asyncio import get_user -from typing import Dict, Union -from supertokens_python.utils import log_debug_message - class OkResponse: status: Literal["OK"] diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py index fd9dcee3e..bc72c61bb 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py @@ -1,24 +1,18 @@ -from typing import Any, Callable, Dict, List, Union +from typing import Any, Dict, List, Union from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailpassword import EmailPasswordRecipe -from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_token as ep_create_reset_password_token, -) -from supertokens_python.recipe.emailpassword.asyncio import ( - reset_password_using_token as ep_reset_password_using_token, -) from supertokens_python.recipe.emailpassword.constants import FORM_FIELD_PASSWORD_ID from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordOkResult, - CreateResetPasswordWrongUserIdError, - ResetPasswordUsingTokenInvalidTokenError, - ResetPasswordUsingTokenOkResult, + UnknownUserIdError, + PasswordResetTokenInvalidError, +) +from supertokens_python.recipe.emailpassword.asyncio import ( + create_reset_password_token, + reset_password_using_token, ) from supertokens_python.recipe.emailpassword.types import NormalisedFormField -from supertokens_python.utils import Awaitable - from ...interfaces import ( APIInterface, APIOptions, @@ -45,21 +39,6 @@ async def handle_user_password_put( async def reset_password( form_fields: List[NormalisedFormField], - create_reset_password_token: Callable[ - [str, str, Dict[str, Any]], - Awaitable[ - Union[CreateResetPasswordOkResult, CreateResetPasswordWrongUserIdError] - ], - ], - reset_password_using_token: Callable[ - [str, str, str, Dict[str, Any]], - Awaitable[ - Union[ - ResetPasswordUsingTokenOkResult, - ResetPasswordUsingTokenInvalidTokenError, - ] - ], - ], ) -> Union[ UserPasswordPutAPIResponse, UserPasswordPutAPIInvalidPasswordErrorResponse ]: @@ -77,10 +56,10 @@ async def reset_password( ) password_reset_token = await create_reset_password_token( - tenant_id, user_id, user_context + tenant_id, user_id, "", user_context ) - if isinstance(password_reset_token, CreateResetPasswordWrongUserIdError): + if isinstance(password_reset_token, UnknownUserIdError): # Techincally it can but its an edge case so we assume that it wont # UNKNOWN_USER_ID_ERROR raise Exception("Should never come here") @@ -89,9 +68,7 @@ async def reset_password( tenant_id, password_reset_token.token, new_password, user_context ) - if isinstance( - password_reset_response, ResetPasswordUsingTokenInvalidTokenError - ): + if isinstance(password_reset_response, PasswordResetTokenInvalidError): # RESET_PASSWORD_INVALID_TOKEN_ERROR raise Exception("Should not come here") @@ -99,6 +76,4 @@ async def reset_password( return await reset_password( EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields, - ep_create_reset_password_token, - ep_reset_password_using_token, ) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py index 4f72e8cc7..9770bd313 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py @@ -11,7 +11,7 @@ ) from supertokens_python.recipe.emailpassword.constants import FORM_FIELD_EMAIL_ID from supertokens_python.recipe.emailpassword.interfaces import ( - UpdateEmailOrPasswordEmailAlreadyExistsError, + EmailAlreadyExistsError, ) from supertokens_python.recipe.passwordless import PasswordlessRecipe from supertokens_python.recipe.passwordless.asyncio import ( @@ -77,12 +77,10 @@ async def update_email_for_recipe_id( return UserPutAPIInvalidEmailErrorResponse(validation_error) email_update_response = await ep_update_email_or_password( - user_id, email, user_context=user_context + RecipeUserId(user_id), email, user_context=user_context ) - if isinstance( - email_update_response, UpdateEmailOrPasswordEmailAlreadyExistsError - ): + if isinstance(email_update_response, EmailAlreadyExistsError): return UserPutAPIEmailAlreadyExistsErrorResponse() return UserPutAPIOkResponse() diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index db2206ed0..de8f138d1 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -187,12 +187,10 @@ def is_valid_recipe_id(recipe_id: str) -> bool: if TYPE_CHECKING: - from supertokens_python.recipe.emailpassword.types import User as EmailPasswordUser from supertokens_python.recipe.passwordless.types import User as PasswordlessUser from supertokens_python.recipe.thirdparty.types import User as ThirdPartyUser GetUserResult = Union[ - EmailPasswordUser, ThirdPartyUser, PasswordlessUser, None, diff --git a/supertokens_python/recipe/emailpassword/api/implementation.py b/supertokens_python/recipe/emailpassword/api/implementation.py index 5498079c7..c0568e0fb 100644 --- a/supertokens_python/recipe/emailpassword/api/implementation.py +++ b/supertokens_python/recipe/emailpassword/api/implementation.py @@ -323,8 +323,12 @@ async def do_update_password_and_verify_email_and_try_link_if_not_primary( ) ) - if isinstance(update_response, EmailAlreadyExistsError) or isinstance( - update_response, UpdateEmailOrPasswordEmailChangeNotAllowedError + if isinstance( + update_response, + ( + EmailAlreadyExistsError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + ), ): raise Exception("Should never happen") if isinstance(update_response, UnknownUserIdError): diff --git a/supertokens_python/recipe/emailpassword/asyncio/__init__.py b/supertokens_python/recipe/emailpassword/asyncio/__init__.py index ec01bcf0a..f3015efd2 100644 --- a/supertokens_python/recipe/emailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/asyncio/__init__.py @@ -123,8 +123,9 @@ async def reset_password_using_token( user_context=user_context, ) - if isinstance(result, EmailAlreadyExistsError) or isinstance( - result, UpdateEmailOrPasswordEmailChangeNotAllowedError + if isinstance( + result, + (EmailAlreadyExistsError, UpdateEmailOrPasswordEmailChangeNotAllowedError), ): raise Exception("Should never happen") diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 9d093f06d..d54a8bed9 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -43,10 +43,10 @@ LinkingToSessionUserFailedResponse, link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, ) +from ...types import AccountLinkingUser if TYPE_CHECKING: from supertokens_python.querier import Querier - from ...types import AccountLinkingUser class RecipeImplementation(RecipeInterface): diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index 85f9d6c89..e7f49520b 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -25,7 +25,7 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) -from supertokens_python.types import AccountLinkingUser +from supertokens_python.types import AccountInfo, AccountLinkingUser def get_users_oldest_first( @@ -160,3 +160,20 @@ def update_or_delete_user_id_mapping_info( user_id, user_id_type, external_user_id_info, user_context ) ) + + +def list_users_by_account_info( + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool = False, + user_context: Optional[Dict[str, Any]] = None, +) -> List[AccountLinkingUser]: + from supertokens_python.asyncio import ( + list_users_by_account_info as async_list_users_by_account_info, + ) + + return sync( + async_list_users_by_account_info( + tenant_id, account_info, do_union_of_account_info, user_context + ) + ) diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 39a149043..cd37164c7 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -392,6 +392,7 @@ async def password_reset_post( async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -405,12 +406,13 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -420,7 +422,7 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) original_implementation.email_exists_get = email_exists_get diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index 81c704a14..550c43ce2 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -24,6 +24,7 @@ from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.recipe.userroles import UserRoleClaim, PermissionClaim +from supertokens_python.types import AccountInfo mode = os.environ.get("APP_MODE", "asgi") @@ -94,14 +95,16 @@ async def check_role_api(): # type: ignore return JsonResponse({"status": "OK"}) async def delete_user(request: HttpRequest): - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email + from supertokens_python.asyncio import list_users_by_account_info from supertokens_python.asyncio import delete_user body = json.loads(request.body) - user = await get_user_by_email("public", body["email"]) - if user is None: + user = await list_users_by_account_info( + "public", AccountInfo(email=body["email"]) + ) + if len(user) == 0: raise Exception("Should not come here") - await delete_user(user.user_id) + await delete_user(user[0].id) return JsonResponse({"status": "OK"}) else: @@ -144,14 +147,14 @@ def sync_unverify_email_api(request: HttpRequest): return JsonResponse({"status": "OK"}) def sync_delete_user(request: HttpRequest): - from supertokens_python.recipe.emailpassword.syncio import get_user_by_email + from supertokens_python.syncio import list_users_by_account_info from supertokens_python.syncio import delete_user body = json.loads(request.body) - user = get_user_by_email("public", body["email"]) - if user is None: + user = list_users_by_account_info("public", AccountInfo(email=body["email"])) + if len(user) == 0: raise Exception("Should not come here") - delete_user(user.user_id) + delete_user(user[0].id) return JsonResponse({"status": "OK"}) @verify_session(override_global_claim_validators=override_global_claim_validators) diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index 81932186d..4513b77c9 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -108,8 +108,8 @@ add_role_to_user, create_new_role_or_add_permissions, ) -from supertokens_python.types import GeneralErrorResponse -from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email +from supertokens_python.types import AccountInfo, GeneralErrorResponse +from supertokens_python.asyncio import list_users_by_account_info from supertokens_python.asyncio import delete_user load_dotenv() @@ -447,6 +447,7 @@ async def password_reset_post( async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -460,12 +461,13 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -475,7 +477,7 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) original_implementation.email_exists_get = email_exists_get @@ -800,10 +802,10 @@ async def set_role_api( @app.post("/deleteUser") async def delete_user_api(request: Request): body = await request.json() - user = await get_user_by_email("public", body["email"]) - if user is None: + user = await list_users_by_account_info("public", AccountInfo(email=body["email"])) + if len(user) == 0: raise Exception("Should not come here") - await delete_user(user.user_id) + await delete_user(user[0].id) return JSONResponse({"status": "OK"}) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 9b7cfa3ee..9b7fc7cca 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -48,7 +48,6 @@ from supertokens_python.recipe.emailpassword.types import ( FormField, InputFormField, - User, ) from supertokens_python.recipe.emailverification import ( EmailVerificationClaim, @@ -103,9 +102,12 @@ add_role_to_user, create_new_role_or_add_permissions, ) -from supertokens_python.types import GeneralErrorResponse -from supertokens_python.recipe.emailpassword.syncio import get_user_by_email -from supertokens_python.syncio import delete_user +from supertokens_python.types import ( + AccountInfo, + AccountLinkingUser, + GeneralErrorResponse, +) +from supertokens_python.syncio import delete_user, list_users_by_account_info load_dotenv() @@ -198,7 +200,7 @@ async def send_email( async def create_and_send_custom_email( - _: User, url_with_token: str, __: Dict[str, Any] + _: AccountLinkingUser, url_with_token: str, __: Dict[str, Any] ) -> None: global latest_url_with_token latest_url_with_token = url_with_token @@ -396,6 +398,7 @@ async def password_reset_post( async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -409,12 +412,13 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -424,7 +428,7 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) original_implementation.email_exists_get = email_exists_get @@ -816,10 +820,10 @@ def verify_email_api(): @app.route("/deleteUser", methods=["POST"]) # type: ignore def delete_user_api(): body: Dict[str, Any] = request.get_json() # type: ignore - user = get_user_by_email("public", body["email"]) - if user is None: + user = list_users_by_account_info("public", AccountInfo(email=body["email"])) + if len(user) == 0: raise Exception("Should not come here") - delete_user(user.user_id) + delete_user(user[0].id) return jsonify({"status": "OK"}) diff --git a/tests/emailpassword/test_emaildelivery.py b/tests/emailpassword/test_emaildelivery.py index 7530bdbcc..d24e88e5f 100644 --- a/tests/emailpassword/test_emaildelivery.py +++ b/tests/emailpassword/test_emaildelivery.py @@ -61,10 +61,6 @@ from supertokens_python.recipe.emailpassword.asyncio import ( send_reset_password_email, ) -from supertokens_python.recipe.emailpassword.interfaces import ( - SendResetPasswordEmailOkResult, - SendResetPasswordEmailUnknownUserIdError, -) respx_mock = respx.MockRouter @@ -1085,8 +1081,8 @@ async def send_email( dict_response = json.loads(response_1.text) user_info = dict_response["user"] assert dict_response["status"] == "OK" - resp = await send_reset_password_email("public", user_info["id"]) - assert isinstance(resp, SendResetPasswordEmailOkResult) + resp = await send_reset_password_email("public", user_info["id"], "") + assert resp == "OK" assert reset_url == "http://supertokens.io/auth/reset-password" assert token_info is not None and "token=" in token_info @@ -1122,9 +1118,9 @@ async def test_send_reset_password_email_invalid_input( dict_response = json.loads(response_1.text) user_info = dict_response["user"] - link = await send_reset_password_email("public", "invalidUserId") - assert isinstance(link, SendResetPasswordEmailUnknownUserIdError) + link = await send_reset_password_email("public", "invalidUserId", "") + assert link == "UNKNOWN_USER_ID_ERROR" with raises(Exception) as err: - await send_reset_password_email("invalidTenantId", user_info["id"]) + await send_reset_password_email("invalidTenantId", user_info["id"], "") assert "status code: 400" in str(err.value) diff --git a/tests/emailpassword/test_multitenancy.py b/tests/emailpassword/test_multitenancy.py index b57683bc0..12d461b0a 100644 --- a/tests/emailpassword/test_multitenancy.py +++ b/tests/emailpassword/test_multitenancy.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from pytest import mark +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.recipe import session, userroles, emailpassword, multitenancy from supertokens_python import init from supertokens_python.recipe.multitenancy.asyncio import ( @@ -20,8 +21,6 @@ from supertokens_python.recipe.emailpassword.asyncio import ( sign_up, sign_in, - get_user_by_id, - get_user_by_email, create_reset_password_token, reset_password_using_token, ) @@ -31,6 +30,7 @@ CreateResetPasswordOkResult, ) from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args from tests.utils import ( @@ -74,9 +74,9 @@ async def test_multitenancy_in_emailpassword(): assert isinstance(user2, SignUpOkResult) assert isinstance(user3, SignUpOkResult) - assert user1.user.user_id != user2.user.user_id - assert user2.user.user_id != user3.user.user_id - assert user3.user.user_id != user1.user.user_id + assert user1.user.id != user2.user.id + assert user2.user.id != user3.user.id + assert user3.user.id != user1.user.id assert user1.user.tenant_ids == ["t1"] assert user2.user.tenant_ids == ["t2"] @@ -91,32 +91,44 @@ async def test_multitenancy_in_emailpassword(): assert isinstance(ep_user2, SignInOkResult) assert isinstance(ep_user3, SignInOkResult) - assert ep_user1.user.user_id == user1.user.user_id - assert ep_user2.user.user_id == user2.user.user_id - assert ep_user3.user.user_id == user3.user.user_id + assert ep_user1.user.id == user1.user.id + assert ep_user2.user.id == user2.user.id + assert ep_user3.user.id == user3.user.id # get user by id: - g_user1 = await get_user_by_id(user1.user.user_id) - g_user2 = await get_user_by_id(user2.user.user_id) - g_user3 = await get_user_by_id(user3.user.user_id) + g_user1 = await get_user(user1.user.id) + g_user2 = await get_user(user2.user.id) + g_user3 = await get_user(user3.user.id) assert g_user1 == user1.user assert g_user2 == user2.user assert g_user3 == user3.user # get user by email: - by_email_user1 = await get_user_by_email("t1", "test@example.com") - by_email_user2 = await get_user_by_email("t2", "test@example.com") - by_email_user3 = await get_user_by_email("t3", "test@example.com") + by_email_user1 = await list_users_by_account_info( + "t1", AccountInfo(email="test@example.com") + ) + by_email_user2 = await list_users_by_account_info( + "t2", AccountInfo(email="test@example.com") + ) + by_email_user3 = await list_users_by_account_info( + "t3", AccountInfo(email="test@example.com") + ) - assert by_email_user1 == user1.user - assert by_email_user2 == user2.user - assert by_email_user3 == user3.user + assert by_email_user1[0] == user1.user + assert by_email_user2[0] == user2.user + assert by_email_user3[0] == user3.user # create password reset token: - pless_reset_link1 = await create_reset_password_token("t1", user1.user.user_id) - pless_reset_link2 = await create_reset_password_token("t2", user2.user.user_id) - pless_reset_link3 = await create_reset_password_token("t3", user3.user.user_id) + pless_reset_link1 = await create_reset_password_token( + "t1", user1.user.id, user1.user.emails[0] + ) + pless_reset_link2 = await create_reset_password_token( + "t2", user2.user.id, user2.user.emails[0] + ) + pless_reset_link3 = await create_reset_password_token( + "t3", user3.user.id, user3.user.emails[0] + ) assert isinstance(pless_reset_link1, CreateResetPasswordOkResult) assert isinstance(pless_reset_link2, CreateResetPasswordOkResult) diff --git a/tests/emailpassword/test_passwordreset.py b/tests/emailpassword/test_passwordreset.py index ff00f46d8..8be716f3b 100644 --- a/tests/emailpassword/test_passwordreset.py +++ b/tests/emailpassword/test_passwordreset.py @@ -27,7 +27,7 @@ from supertokens_python.recipe import emailpassword, session from supertokens_python.recipe.emailpassword.asyncio import create_reset_password_link from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordLinkUnknownUserIdError, + UnknownUserIdError, ) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import ( @@ -383,19 +383,20 @@ async def test_create_reset_password_link( dict_response = json.loads(response_1.text) user_info = dict_response["user"] assert dict_response["status"] == "OK" - link = await create_reset_password_link("public", user_info["id"]) - url = urlparse(link.link) # type: ignore + link = await create_reset_password_link("public", user_info["id"], "") + assert isinstance(link, str) + url = urlparse(link) queries = url.query.strip("&").split("&") assert url.path == "/auth/reset-password" assert "token=" in queries[0] assert "tenantId=public" in queries assert "rid=emailpassword" not in queries - link = await create_reset_password_link("public", "invalidUserId") - assert isinstance(link, CreateResetPasswordLinkUnknownUserIdError) + link = await create_reset_password_link("public", "invalidUserId", "") + assert isinstance(link, UnknownUserIdError) with raises(Exception) as err: - await create_reset_password_link("invalidTenantId", user_info["id"]) + await create_reset_password_link("invalidTenantId", user_info["id"], "") assert "status code: 400" in str(err.value) diff --git a/tests/emailpassword/test_updateemailorpassword.py b/tests/emailpassword/test_updateemailorpassword.py index 586e7d5c4..71a53112b 100644 --- a/tests/emailpassword/test_updateemailorpassword.py +++ b/tests/emailpassword/test_updateemailorpassword.py @@ -15,6 +15,7 @@ from typing import Any from fastapi import FastAPI +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -73,7 +74,7 @@ async def test_update_email_or_password_with_default_validator( user_id = dict_response["user"]["id"] r = await update_email_or_password( - user_id=user_id, + recipe_user_id=RecipeUserId(user_id), email=None, password="test", user_context={}, @@ -126,7 +127,7 @@ async def validate_pass(value: Any, _tenant_id: str): user_id = dict_response["user"]["id"] r = await update_email_or_password( - user_id=user_id, + recipe_user_id=RecipeUserId(user_id), email=None, password="te", user_context={}, diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 627737649..53d75e7f9 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -17,6 +17,7 @@ from typing import Any, Dict from supertokens_python import init +from supertokens_python.asyncio import get_user from supertokens_python.framework.fastapi import get_middleware from supertokens_python.recipe import emailpassword, multitenancy, session from supertokens_python.types import RecipeUserId @@ -43,7 +44,7 @@ associate_user_to_tenant, dissociate_user_from_tenant, ) -from supertokens_python.recipe.emailpassword.asyncio import sign_up, get_user_by_id +from supertokens_python.recipe.emailpassword.asyncio import sign_up from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.recipe.thirdparty.provider import ( @@ -298,13 +299,13 @@ async def test_user_association_and_disassociation_with_tenants(): signup_response = await sign_up("public", "test@example.com", "password1") assert isinstance(signup_response, SignUpOkResult) - user_id = signup_response.user.user_id + user_id = signup_response.user.id await associate_user_to_tenant("t1", RecipeUserId(user_id)) await associate_user_to_tenant("t2", RecipeUserId(user_id)) await associate_user_to_tenant("t3", RecipeUserId(user_id)) - user = await get_user_by_id(user_id) + user = await get_user(user_id) assert user is not None assert len(user.tenant_ids) == 4 # public + 3 tenants @@ -312,6 +313,6 @@ async def test_user_association_and_disassociation_with_tenants(): await dissociate_user_from_tenant("t2", RecipeUserId(user_id)) await dissociate_user_from_tenant("t3", RecipeUserId(user_id)) - user = await get_user_by_id(user_id) + user = await get_user(user_id) assert user is not None assert len(user.tenant_ids) == 1 # public only diff --git a/tests/supertokens_python/test_supertokens_functions.py b/tests/supertokens_python/test_supertokens_functions.py index fdd8374bf..721db4a02 100644 --- a/tests/supertokens_python/test_supertokens_functions.py +++ b/tests/supertokens_python/test_supertokens_functions.py @@ -61,7 +61,7 @@ async def test_supertokens_functions(): for e in emails: signup_resp = await ep_asyncio.sign_up("public", e, "secret_pass") assert isinstance(signup_resp, SignUpOkResult) - user_ids.append(signup_resp.user.user_id) + user_ids.append(signup_resp.user.id) # Get user count assert await st_asyncio.get_user_count() == len(emails) diff --git a/tests/test-server/emailpassword.py b/tests/test-server/emailpassword.py index efde717b5..3b6f73b4b 100644 --- a/tests/test-server/emailpassword.py +++ b/tests/test-server/emailpassword.py @@ -1,11 +1,11 @@ from flask import Flask, request, jsonify from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordLinkOkResult, + EmailAlreadyExistsError, SignInOkResult, SignUpOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordUnknownUserIdError, ) import supertokens_python.recipe.emailpassword.syncio as emailpassword @@ -28,8 +28,8 @@ def emailpassword_signup(): # type: ignore { "status": "OK", "user": { - "id": response.user.user_id, - "email": response.user.email, + "id": response.user.id, + "email": response.user.emails[0], "timeJoined": response.user.time_joined, "tenantIds": response.user.tenant_ids, }, @@ -56,8 +56,8 @@ def emailpassword_signin(): # type: ignore { "status": "OK", "user": { - "id": response.user.user_id, - "email": response.user.email, + "id": response.user.id, + "email": response.user.emails[0], "timeJoined": response.user.time_joined, "tenantIds": response.user.tenant_ids, }, @@ -80,8 +80,8 @@ def emailpassword_create_reset_password_link(): # type: ignore tenant_id, user_id, user_context ) - if isinstance(response, CreateResetPasswordLinkOkResult): - return jsonify({"status": "OK", "link": response.link}) + if isinstance(response, str): + return jsonify({"status": "OK", "link": response}) else: return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) @@ -109,10 +109,12 @@ def emailpassword_update_email_or_password(): # type: ignore if isinstance(response, UpdateEmailOrPasswordOkResult): return jsonify({"status": "OK"}) - elif isinstance(response, UpdateEmailOrPasswordUnknownUserIdError): + elif isinstance(response, UnknownUserIdError): return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) - elif isinstance(response, UpdateEmailOrPasswordEmailAlreadyExistsError): + elif isinstance(response, EmailAlreadyExistsError): return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return jsonify({"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR"}) else: return jsonify( { diff --git a/tests/test_querier.py b/tests/test_querier.py index 2b4746a14..b6722a019 100644 --- a/tests/test_querier.py +++ b/tests/test_querier.py @@ -20,9 +20,9 @@ thirdparty, ) from supertokens_python import InputAppInfo -from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id, sign_up +from supertokens_python.recipe.emailpassword.asyncio import get_user, sign_up from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_id as tp_get_user_by_id, + get_user_by_id as tp_get_user, ) import asyncio import respx @@ -236,29 +236,29 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await tp_get_user_by_id("random", user_context) + user = await tp_get_user("random", user_context) assert user is None assert called_core called_core = False - user = await tp_get_user_by_id("random", user_context) + user = await tp_get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core @@ -299,22 +299,22 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core @@ -357,14 +357,14 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core @@ -407,20 +407,20 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core called_core = False - user = await tp_get_user_by_id("random", user_context) + user = await tp_get_user("random", user_context) assert user is None assert called_core @@ -461,7 +461,7 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core @@ -470,7 +470,7 @@ def intercept( called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core @@ -513,33 +513,33 @@ def intercept( user_context: Dict[str, Any] = {"_default": {"keep_cache_alive": True}} user_context_2: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert not called_core @@ -583,33 +583,33 @@ def intercept( user_context: Dict[str, Any] = {"_default": {"keep_cache_alive": False}} user_context_2: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core @@ -653,33 +653,33 @@ def intercept( user_context: Dict[str, Any] = {} user_context_2: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core diff --git a/tests/test_user_context.py b/tests/test_user_context.py index 55bd67ad3..e5089c72c 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -74,12 +74,13 @@ def apis_override_email_password(param: APIInterface): async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[session.SessionContainer], api_options: APIOptions, user_context: Dict[str, Any], ): user_context = {"preSignInPOST": True} response = await og_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) if ( "preSignInPOST" in user_context @@ -100,20 +101,32 @@ def functions_override_email_password(param: RecipeInterface): og_sign_up = param.sign_up async def sign_up_( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + user_context: Dict[str, Any], ): if "manualCall" in user_context: global signUpContextWorks signUpContextWorks = True - response = await og_sign_up(email, password, tenant_id, user_context) + response = await og_sign_up( + email, password, tenant_id, session, user_context + ) return response async def sign_in( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + user_context: Dict[str, Any], ): if "preSignInPOST" in user_context: user_context["preSignIn"] = True - response = await og_sign_in(email, password, tenant_id, user_context) + response = await og_sign_in( + email, password, tenant_id, session, user_context + ) if "preSignInPOST" in user_context and "preSignIn" in user_context: user_context["postSignIn"] = True return response @@ -185,7 +198,9 @@ async def create_new_session( ) start_st() - await sign_up("public", "random@gmail.com", "validpass123", {"manualCall": True}) + await sign_up( + "public", "random@gmail.com", "validpass123", None, {"manualCall": True} + ) res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") assert res.status_code == 200 @@ -207,6 +222,7 @@ def apis_override_email_password(param: APIInterface): async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[session.SessionContainer], api_options: APIOptions, user_context: Dict[str, Any], ): @@ -216,7 +232,7 @@ async def sign_in_post( signin_api_context_works = True return await og_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) param.sign_in_post = sign_in_post @@ -226,14 +242,18 @@ def functions_override_email_password(param: RecipeInterface): og_sign_in = param.sign_in async def sign_in( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + user_context: Dict[str, Any], ): req = user_context.get("_default", {}).get("request") if req: nonlocal signin_context_works signin_context_works = True - return await og_sign_in(email, password, tenant_id, user_context) + return await og_sign_in(email, password, tenant_id, session, user_context) param.sign_in = sign_in return param @@ -293,7 +313,9 @@ async def create_new_session( ) start_st() - await sign_up("public", "random@gmail.com", "validpass123", {"manualCall": True}) + await sign_up( + "public", "random@gmail.com", "validpass123", None, {"manualCall": True} + ) res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") assert res.status_code == 200 @@ -320,6 +342,7 @@ def apis_override_email_password(param: APIInterface): async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[session.SessionContainer], api_options: APIOptions, user_context: Dict[str, Any], ): @@ -331,7 +354,7 @@ async def sign_in_post( signin_api_context_works = True return await og_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, tenant_id, session, api_options, user_context ) param.sign_in_post = sign_in_post @@ -341,7 +364,11 @@ def functions_override_email_password(param: RecipeInterface): og_sign_in = param.sign_in async def sign_in( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + user_context: Dict[str, Any], ): req = get_request_from_user_context(user_context) if req: @@ -358,7 +385,7 @@ async def sign_in( user_context["_default"]["request"] = orginal_request - return await og_sign_in(email, password, tenant_id, user_context) + return await og_sign_in(email, password, tenant_id, session, user_context) param.sign_in = sign_in return param @@ -420,7 +447,9 @@ async def create_new_session( ) start_st() - await sign_up("public", "random@gmail.com", "validpass123", {"manualCall": True}) + await sign_up( + "public", "random@gmail.com", "validpass123", None, {"manualCall": True} + ) res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") assert res.status_code == 200 diff --git a/tests/useridmapping/create_user_id_mapping.py b/tests/useridmapping/create_user_id_mapping.py index a872d413a..1cff74892 100644 --- a/tests/useridmapping/create_user_id_mapping.py +++ b/tests/useridmapping/create_user_id_mapping.py @@ -55,7 +55,7 @@ async def test_create_user_id_mapping(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_user_info = "externalIdInfo" @@ -87,7 +87,7 @@ async def test_create_user_id_mapping_without_and_with_force(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" # Add metadata to the user: @@ -146,7 +146,7 @@ async def create_user_id_mapping_when_mapping_already_exists(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" # Create User ID Mapping: @@ -169,7 +169,7 @@ async def create_user_id_mapping_when_mapping_already_exists(): # Try creating a duplicate mapping where external_user_id exists and but supertokens_user_id doesn't (new) sign_up_res = await sign_up("public", "foo@bar.com", "baz") assert isinstance(sign_up_res, SignUpOkResult) - new_supertokens_user_id = sign_up_res.user.user_id + new_supertokens_user_id = sign_up_res.user.id res = await create_user_id_mapping(new_supertokens_user_id, external_user_id) assert isinstance(res, UserIdMappingAlreadyExistsError) diff --git a/tests/useridmapping/delete_user_id_mapping.py b/tests/useridmapping/delete_user_id_mapping.py index 0915d6d33..ebc3af883 100644 --- a/tests/useridmapping/delete_user_id_mapping.py +++ b/tests/useridmapping/delete_user_id_mapping.py @@ -66,7 +66,7 @@ async def test_delete_user_id_mapping(user_type: USER_TYPE): sign_up_res = await sign_up("public", "test@example.com", "password") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_id_info = "externalIdInfo" @@ -104,7 +104,7 @@ async def test_delete_user_id_mapping_without_and_with_force(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_user_info = "externalIdInfo" diff --git a/tests/useridmapping/get_user_id_mapping.py b/tests/useridmapping/get_user_id_mapping.py index 517d076f4..030f1824f 100644 --- a/tests/useridmapping/get_user_id_mapping.py +++ b/tests/useridmapping/get_user_id_mapping.py @@ -62,7 +62,7 @@ async def test_get_user_id_mapping(use_external_id_info: bool): sign_up_res = await sign_up("public", "test@example.com", "password") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_id_info = "externalIdInfo" if use_external_id_info else None diff --git a/tests/useridmapping/recipe_tests.py b/tests/useridmapping/recipe_tests.py index 42e2d1c94..cf4011a2d 100644 --- a/tests/useridmapping/recipe_tests.py +++ b/tests/useridmapping/recipe_tests.py @@ -21,10 +21,11 @@ from supertokens_python.querier import Querier from supertokens_python.recipe.emailpassword.interfaces import ( SignUpOkResult, - ResetPasswordUsingTokenOkResult, SignInOkResult, CreateResetPasswordOkResult, + UpdateEmailOrPasswordOkResult, ) +from supertokens_python.types import AccountInfo, RecipeUserId from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from .utils import st_config @@ -55,23 +56,23 @@ async def ep_get_new_user_id(email: str) -> str: sign_up_res = await sign_up("public", email, "password") assert isinstance(sign_up_res, SignUpOkResult) - return sign_up_res.user.user_id + return sign_up_res.user.id async def ep_get_existing_user_id(user_id: str) -> str: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id + from supertokens_python.asyncio import get_user - res = await get_user_by_id(user_id) + res = await get_user(user_id) assert res is not None - return res.user_id + return res.id async def ep_get_existing_user_by_email(email: str) -> str: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email + from supertokens_python.asyncio import list_users_by_account_info - res = await get_user_by_email("public", email) - assert res is not None - return res.user_id + res = await list_users_by_account_info("public", AccountInfo(email=email)) + assert len(res) == 1 + return res[0].id async def ep_get_existing_user_by_signin(email: str) -> str: @@ -79,7 +80,7 @@ async def ep_get_existing_user_by_signin(email: str) -> str: res = await sign_in("public", email, "password") assert isinstance(res, SignInOkResult) - return res.user.user_id + return res.user.id async def ep_get_existing_user_after_reset_password(user_id: str) -> str: @@ -89,12 +90,11 @@ async def ep_get_existing_user_after_reset_password(user_id: str) -> str: reset_password_using_token, ) - result = await create_reset_password_token("public", user_id) + result = await create_reset_password_token("public", user_id, "") assert isinstance(result, CreateResetPasswordOkResult) res = await reset_password_using_token("public", result.token, new_password) - assert isinstance(res, ResetPasswordUsingTokenOkResult) - assert res.user_id is not None - return res.user_id + assert isinstance(res, UpdateEmailOrPasswordOkResult) + return user_id async def ep_get_existing_user_after_updating_email_and_sign_in(user_id: str) -> str: @@ -105,12 +105,12 @@ async def ep_get_existing_user_after_updating_email_and_sign_in(user_id: str) -> sign_in, ) - res = await update_email_or_password(user_id, new_email, "password") + res = await update_email_or_password(RecipeUserId(user_id), new_email, "password") assert isinstance(res, SignUpOkResult) res = await sign_in("public", new_email, "password") assert isinstance(res, SignInOkResult) - return res.user.user_id + return res.user.id @mark.parametrize("use_external_id_info", [(True,), (False,)]) diff --git a/tests/userroles/test_multitenancy.py b/tests/userroles/test_multitenancy.py index 30e557102..b67578172 100644 --- a/tests/userroles/test_multitenancy.py +++ b/tests/userroles/test_multitenancy.py @@ -63,7 +63,7 @@ async def test_multitenancy_in_user_roles(): user = await sign_up("public", "test@example.com", "password1") assert isinstance(user, SignUpOkResult) - user_id = user.user.user_id + user_id = user.user.id await associate_user_to_tenant("t1", RecipeUserId(user_id)) await associate_user_to_tenant("t2", RecipeUserId(user_id)) From 11797c6a99b099014255556e6340bbeb8a14ba0a Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 4 Sep 2024 17:39:14 +0530 Subject: [PATCH 028/126] complets passwordless changes without fixing all types --- supertokens_python/auth_utils.py | 62 +- .../recipe/emailpassword/api/signin.py | 2 + .../recipe/emailpassword/api/signup.py | 3 + .../recipe/emailpassword/asyncio/__init__.py | 2 +- .../recipe/emailpassword/interfaces.py | 22 +- .../emailpassword/recipe_implementation.py | 9 +- .../recipe/passwordless/api/consume_code.py | 11 + .../recipe/passwordless/api/create_code.py | 11 + .../recipe/passwordless/api/implementation.py | 589 +++++++++++++++--- .../recipe/passwordless/api/resend_code.py | 12 +- .../recipe/passwordless/asyncio/__init__.py | 98 ++- .../recipe/passwordless/interfaces.py | 158 +++-- .../recipe/passwordless/recipe.py | 285 ++++++++- .../passwordless/recipe_implementation.py | 574 +++++++++-------- .../recipe/passwordless/syncio/__init__.py | 137 ++-- .../recipe/passwordless/types.py | 30 +- .../recipe/passwordless/utils.py | 50 +- 17 files changed, 1445 insertions(+), 610 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 40a875540..ae9b02f77 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -41,26 +41,7 @@ from .asyncio import get_user -class OkResponse: - status: Literal["OK"] - valid_factor_ids: List[str] - is_first_factor: bool - - def __init__(self, valid_factor_ids: List[str], is_first_factor: bool): - self.status = "OK" - self.valid_factor_ids = valid_factor_ids - self.is_first_factor = is_first_factor - - -class SignUpNotAllowedResponse: - status: Literal["SIGN_UP_NOT_ALLOWED"] - - -class SignInNotAllowedResponse: - status: Literal["SIGN_IN_NOT_ALLOWED"] - - -class LinkingToSessionUserFailedResponse: +class LinkingToSessionUserFailedError: status: Literal["LINKING_TO_SESSION_USER_FAILED"] reason: Literal[ "EMAIL_VERIFICATION_REQUIRED", @@ -83,6 +64,25 @@ def __init__( self.reason = reason +class OkResponse: + status: Literal["OK"] + valid_factor_ids: List[str] + is_first_factor: bool + + def __init__(self, valid_factor_ids: List[str], is_first_factor: bool): + self.status = "OK" + self.valid_factor_ids = valid_factor_ids + self.is_first_factor = is_first_factor + + +class SignUpNotAllowedResponse: + status: Literal["SIGN_UP_NOT_ALLOWED"] + + +class SignInNotAllowedResponse: + status: Literal["SIGN_IN_NOT_ALLOWED"] + + async def pre_auth_checks( authenticating_account_info: AccountInfoWithRecipeId, authenticating_user: Union[AccountLinkingUser, None], @@ -98,7 +98,7 @@ async def pre_auth_checks( OkResponse, SignUpNotAllowedResponse, SignInNotAllowedResponse, - LinkingToSessionUserFailedResponse, + LinkingToSessionUserFailedError, ]: valid_factor_ids: List[str] = [] @@ -491,7 +491,7 @@ async def check_auth_type_and_linking_status( OkFirstFactorResponse, OkSecondFactorLinkedResponse, OkSecondFactorNotLinkedResponse, - LinkingToSessionUserFailedResponse, + LinkingToSessionUserFailedError, ]: log_debug_message("check_auth_type_and_linking_status called") session_user: Union[AccountLinkingUser, None] = None @@ -529,7 +529,7 @@ async def check_auth_type_and_linking_status( session_user_result.status == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ): - return LinkingToSessionUserFailedResponse( + return LinkingToSessionUserFailedError( reason="SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ) @@ -572,7 +572,7 @@ async def link_to_session_if_provided_else_create_primary_user_id_or_link_by_acc recipe_user_id: RecipeUserId, session: Union[SessionContainer, None], user_context: Dict[str, Any], -) -> Union[OkResponse2, LinkingToSessionUserFailedResponse,]: +) -> Union[OkResponse2, LinkingToSessionUserFailedError,]: log_debug_message( "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info called" ) @@ -623,7 +623,7 @@ async def retry(): OkSecondFactorNotLinkedResponse, ), ): - return LinkingToSessionUserFailedResponse(reason=auth_type_res.reason) + return LinkingToSessionUserFailedError(reason=auth_type_res.reason) if isinstance(auth_type_res, OkFirstFactorResponse): if not recipe_init_defined_should_do_automatic_account_linking(): @@ -660,7 +660,7 @@ async def retry(): linking_to_session_user_requires_verification=auth_type_res.linking_to_session_user_requires_verification, user_context=user_context, ) - if isinstance(session_linking_res, LinkingToSessionUserFailedResponse): + if isinstance(session_linking_res, LinkingToSessionUserFailedError): if session_linking_res.reason == "INPUT_USER_IS_NOT_A_PRIMARY_USER": return await retry() else: @@ -778,7 +778,7 @@ async def try_linking_by_session( authenticated_user: AccountLinkingUser, session_user: AccountLinkingUser, user_context: Dict[str, Any], -) -> Union[OkResponse2, LinkingToSessionUserFailedResponse,]: +) -> Union[OkResponse2, LinkingToSessionUserFailedError,]: log_debug_message("tryLinkingBySession called") session_user_has_verified_account_info = any( @@ -797,7 +797,7 @@ async def try_linking_by_session( ) if not can_link_based_on_verification: - return LinkingToSessionUserFailedResponse(reason="EMAIL_VERIFICATION_REQUIRED") + return LinkingToSessionUserFailedError(reason="EMAIL_VERIFICATION_REQUIRED") link_accounts_result = ( await AccountLinkingRecipe.get_instance().recipe_implementation.link_accounts( @@ -819,21 +819,21 @@ async def try_linking_by_session( log_debug_message( "tryLinkingBySession linking to session user failed because of a race condition - input user linked to another user" ) - return LinkingToSessionUserFailedResponse( + return LinkingToSessionUserFailedError( reason="RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ) elif link_accounts_result.status == "INPUT_USER_IS_NOT_A_PRIMARY_USER": log_debug_message( "tryLinkingBySession linking to session user failed because of a race condition - INPUT_USER_IS_NOT_A_PRIMARY_USER, should retry" ) - return LinkingToSessionUserFailedResponse( + return LinkingToSessionUserFailedError( reason="INPUT_USER_IS_NOT_A_PRIMARY_USER" ) else: log_debug_message( "tryLinkingBySession linking to session user failed because of a race condition - input user has another primary user it can be linked to" ) - return LinkingToSessionUserFailedResponse( + return LinkingToSessionUserFailedError( reason="ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ) diff --git a/supertokens_python/recipe/emailpassword/api/signin.py b/supertokens_python/recipe/emailpassword/api/signin.py index 2fbfa2271..c41e02a15 100644 --- a/supertokens_python/recipe/emailpassword/api/signin.py +++ b/supertokens_python/recipe/emailpassword/api/signin.py @@ -54,6 +54,8 @@ async def handle_sign_in_api( override_global_claim_validators=lambda _, __, ___: [], user_context=user_context, ) + if session is not None: + tenant_id = session.get_tenant_id() response = await api_implementation.sign_in_post( form_fields, tenant_id, session, api_options, user_context diff --git a/supertokens_python/recipe/emailpassword/api/signup.py b/supertokens_python/recipe/emailpassword/api/signup.py index f7ad8ec26..7625dc888 100644 --- a/supertokens_python/recipe/emailpassword/api/signup.py +++ b/supertokens_python/recipe/emailpassword/api/signup.py @@ -59,6 +59,9 @@ async def handle_sign_up_api( user_context=user_context, ) + if session is not None: + tenant_id = session.get_tenant_id() + response = await api_implementation.sign_up_post( form_fields, tenant_id, session, api_options, user_context ) diff --git a/supertokens_python/recipe/emailpassword/asyncio/__init__.py b/supertokens_python/recipe/emailpassword/asyncio/__init__.py index f3015efd2..decb7efa1 100644 --- a/supertokens_python/recipe/emailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/asyncio/__init__.py @@ -15,6 +15,7 @@ from typing_extensions import Literal from supertokens_python import get_request_from_user_context from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.session import SessionContainer @@ -33,7 +34,6 @@ PasswordPolicyViolationError, SignUpOkResult, EmailAlreadyExistsError, - LinkingToSessionUserFailedError, SignInOkResult, WrongCredentialsError, ) diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 549666f57..42f71fba5 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -15,12 +15,16 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union -from typing_extensions import Literal +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.emailpassword.types import EmailTemplateVars from ...supertokens import AppInfo -from ...types import APIResponse, GeneralErrorResponse, RecipeUserId +from ...types import ( + APIResponse, + GeneralErrorResponse, + RecipeUserId, +) if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse @@ -46,20 +50,6 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class LinkingToSessionUserFailedError: - def __init__( - self, - reason: Literal[ - "EMAIL_VERIFICATION_REQUIRED", - "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", - "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", - "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", - "INPUT_USER_IS_NOT_A_PRIMARY_USER", - ], - ): - self.reason = reason - - class SignInOkResult: def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): self.user = user diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index d54a8bed9..ea132d6aa 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -35,12 +35,11 @@ SignUpOkResult, UpdateEmailOrPasswordOkResult, PasswordPolicyViolationError, - LinkingToSessionUserFailedError, ) from .utils import EmailPasswordConfig from .constants import FORM_FIELD_PASSWORD_ID from supertokens_python.auth_utils import ( - LinkingToSessionUserFailedResponse, + LinkingToSessionUserFailedError, link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, ) from ...types import AccountLinkingUser @@ -88,7 +87,7 @@ async def sign_up( user_context=user_context, ) - if isinstance(link_result, LinkingToSessionUserFailedResponse): + if isinstance(link_result, LinkingToSessionUserFailedError): return LinkingToSessionUserFailedError(reason=link_result.reason) updated_user = link_result.user @@ -167,8 +166,8 @@ async def sign_in( user_context=user_context, ) - if isinstance(link_result, LinkingToSessionUserFailedResponse): - return LinkingToSessionUserFailedError(reason=link_result.reason) + if isinstance(link_result, LinkingToSessionUserFailedError): + return link_result response.user = link_result.user diff --git a/supertokens_python/recipe/passwordless/api/consume_code.py b/supertokens_python/recipe/passwordless/api/consume_code.py index 442c6ba7b..4da19d9bf 100644 --- a/supertokens_python/recipe/passwordless/api/consume_code.py +++ b/supertokens_python/recipe/passwordless/api/consume_code.py @@ -14,6 +14,7 @@ from typing import Any, Dict from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.utils import send_200_response @@ -56,11 +57,21 @@ async def consume_code( pre_auth_session_id = body["preAuthSessionId"] + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.consume_code_post( pre_auth_session_id, user_input_code, device_id, link_code, + session, tenant_id, api_options, user_context, diff --git a/supertokens_python/recipe/passwordless/api/create_code.py b/supertokens_python/recipe/passwordless/api/create_code.py index c4c162f86..d69d69749 100644 --- a/supertokens_python/recipe/passwordless/api/create_code.py +++ b/supertokens_python/recipe/passwordless/api/create_code.py @@ -22,6 +22,7 @@ ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, ) +from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.types import GeneralErrorResponse from supertokens_python.utils import send_200_response @@ -109,9 +110,19 @@ async def create_code( except Exception: phone_number = phone_number.strip() + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.create_code_post( email=email, phone_number=phone_number, + session=session, tenant_id=tenant_id, api_options=api_options, user_context=user_context, diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 840da912a..5ece56109 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -11,14 +11,31 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import ( + OkResponse, + PostAuthChecksOkResponse, + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + check_auth_type_and_linking_status, + filter_out_invalid_first_factors_or_throw_if_all_are_invalid, + get_authenticating_user_and_add_to_current_tenant_if_required, + post_auth_checks, + pre_auth_checks, +) from supertokens_python.logger import log_debug_message +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.multifactorauth.types import FactorIds from supertokens_python.recipe.passwordless.interfaces import ( APIInterface, APIOptions, + CheckCodeOkResult, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeOkResult, ConsumeCodePostExpiredUserInputCodeError, ConsumeCodePostIncorrectUserInputCodeError, ConsumeCodePostOkResult, @@ -32,6 +49,7 @@ PhoneNumberExistsGetOkResult, ResendCodePostOkResult, ResendCodePostRestartFlowError, + SignInUpPostNotAllowedResponse, ) from supertokens_python.recipe.passwordless.types import ( PasswordlessLoginSMSTemplateVars, @@ -40,33 +58,208 @@ ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, + get_enabled_pwless_factors, +) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.types import ( + AccountInfo, + AccountLinkingUser, + GeneralErrorResponse, + LoginMethod, + RecipeUserId, ) -from supertokens_python.recipe.session.asyncio import create_new_session -from supertokens_python.types import GeneralErrorResponse, RecipeUserId from ...emailverification import EmailVerificationRecipe from ...emailverification.interfaces import CreateEmailVerificationTokenOkResult +class PasswordlessUserResult: + user: AccountLinkingUser + login_method: Union[LoginMethod, None] + + def __init__( + self, user: AccountLinkingUser, login_method: Union[LoginMethod, None] + ): + self.user = user + self.login_method = login_method + + +async def get_passwordless_user_by_account_info( + tenant_id: str, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + account_info: AccountInfo, +) -> Optional[PasswordlessUserResult]: + existing_users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=account_info, + do_union_of_account_info=False, + user_context=user_context, + ) + log_debug_message( + f"get_passwordless_user_by_account_info got {len(existing_users)} from core resp {account_info}" + ) + + users_with_matching_login_methods = [ + PasswordlessUserResult( + user=user, + login_method=next( + ( + lm + for lm in user.login_methods + if lm.recipe_id == "passwordless" + and ( + lm.has_same_email_as(account_info.email) + or lm.has_same_phone_number_as(account_info.phone_number) + ) + ), + None, + ), + ) + for user in existing_users + ] + users_with_matching_login_methods = [ + user_data + for user_data in users_with_matching_login_methods + if user_data.login_method is not None + ] + + log_debug_message( + f"get_passwordless_user_by_account_info {len(users_with_matching_login_methods)} has matching login methods" + ) + + if len(users_with_matching_login_methods) > 1: + raise Exception( + "This should never happen: multiple users exist matching the accountInfo in passwordless createCode" + ) + + return users_with_matching_login_methods[0] + + class APIImplementation(APIInterface): async def create_code_post( self, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[CreateCodePostOkResult, GeneralErrorResponse]: + ) -> Union[ + CreateCodePostOkResult, SignInUpPostNotAllowedResponse, GeneralErrorResponse + ]: + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_002)", + "LINKING_TO_SESSION_USER_FAILED": { + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_019)", + }, + } + + account_info = AccountInfo( + email=email, + phone_number=phone_number, + ) + + user_with_matching_login_method = await get_passwordless_user_by_account_info( + tenant_id, session, user_context, account_info + ) + + factor_ids = [] + if session is not None: + factor_ids = [ + FactorIds.OTP_EMAIL if email is not None else FactorIds.OTP_PHONE + ] + else: + factor_ids = get_enabled_pwless_factors(api_options.config) + if email is not None: + factor_ids = [ + f + for f in factor_ids + if f in [FactorIds.OTP_EMAIL, FactorIds.LINK_EMAIL] + ] + else: + factor_ids = [ + f + for f in factor_ids + if f in [FactorIds.OTP_PHONE, FactorIds.LINK_PHONE] + ] + + is_verified_input = True + if user_with_matching_login_method is not None: + assert user_with_matching_login_method.login_method is not None + is_verified_input = user_with_matching_login_method.login_method.verified + + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id="passwordless", + email=account_info.email, + phone_number=account_info.phone_number, + ), + is_sign_up=user_with_matching_login_method is None, + authenticating_user=( + user_with_matching_login_method.user + if user_with_matching_login_method + else None + ), + is_verified=is_verified_input, + sign_in_verifies_login_method=True, + skip_session_user_update_in_core=True, + tenant_id=tenant_id, + factor_ids=factor_ids, + user_context=user_context, + session=session, + ) + + if not isinstance(pre_auth_checks_result, OkResponse): + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInUpPostNotAllowedResponse(reason=reason) + user_input_code = None if api_options.config.get_custom_user_input_code is not None: user_input_code = await api_options.config.get_custom_user_input_code( tenant_id, user_context ) + + user_input_code_input = user_input_code + if api_options.config.get_custom_user_input_code is not None: + user_input_code_input = await api_options.config.get_custom_user_input_code( + tenant_id, user_context + ) response = await api_options.recipe_implementation.create_code( - email, phone_number, user_input_code, tenant_id, user_context + email=account_info.email, + phone_number=account_info.phone_number, + user_input_code=user_input_code_input, + tenant_id=tenant_id, + user_context=user_context, + session=session, ) + magic_link = None user_input_code = None flow_type = api_options.config.flow_type + + if all( + _id.startswith("link") for _id in pre_auth_checks_result.valid_factor_ids + ): + flow_type = "MAGIC_LINK" + elif all( + _id.startswith("otp") for _id in pre_auth_checks_result.valid_factor_ids + ): + flow_type = "USER_INPUT_CODE" + else: + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + if flow_type in ("MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"): magic_link = ( api_options.app_info.get_origin( @@ -99,6 +292,7 @@ async def create_code_post( code_life_time=response.code_life_time, pre_auth_session_id=response.pre_auth_session_id, tenant_id=tenant_id, + is_first_factor=pre_auth_checks_result.is_first_factor, ) await api_options.email_delivery.ingredient_interface_impl.send_email( passwordless_email_delivery_input, user_context @@ -117,6 +311,7 @@ async def create_code_post( code_life_time=response.code_life_time, pre_auth_session_id=response.pre_auth_session_id, tenant_id=tenant_id, + is_first_factor=pre_auth_checks_result.is_first_factor, ) await api_options.sms_delivery.ingredient_interface_impl.send_sms( sms_input, user_context @@ -130,6 +325,7 @@ async def resend_code_post( self, device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -139,16 +335,48 @@ async def resend_code_post( device_info = await api_options.recipe_implementation.list_codes_by_device_id( device_id=device_id, tenant_id=tenant_id, user_context=user_context ) + if device_info is None: return ResendCodePostRestartFlowError() + if ( - isinstance(api_options.config.contact_config, ContactEmailOnlyConfig) - and device_info.email is None - ) or ( - isinstance(api_options.config.contact_config, ContactPhoneOnlyConfig) + api_options.config.contact_config.contact_method == "PHONE" and device_info.phone_number is None + ) or ( + api_options.config.contact_config.contact_method == "EMAIL" + and device_info.email is None ): return ResendCodePostRestartFlowError() + + user_with_matching_login_method = await get_passwordless_user_by_account_info( + tenant_id=tenant_id, + session=session, + user_context=user_context, + account_info=AccountInfo( + email=device_info.email, + phone_number=device_info.phone_number, + ), + ) + + auth_type_info = await check_auth_type_and_linking_status( + session=session, + account_info=AccountInfoWithRecipeId( + recipe_id="passwordless", + email=device_info.email, + phone_number=device_info.phone_number, + ), + input_user=( + user_with_matching_login_method.user + if user_with_matching_login_method + else None + ), + skip_session_user_update_in_core=True, + user_context=user_context, + ) + + if auth_type_info.status == "LINKING_TO_SESSION_USER_FAILED": + return ResendCodePostRestartFlowError() + number_of_tries_to_create_new_code = 0 while True: number_of_tries_to_create_new_code += 1 @@ -157,14 +385,22 @@ async def resend_code_post( user_input_code = await api_options.config.get_custom_user_input_code( tenant_id, user_context ) + user_input_code_input = user_input_code + if api_options.config.get_custom_user_input_code is not None: + user_input_code_input = ( + await api_options.config.get_custom_user_input_code( + tenant_id, user_context + ) + ) response = ( await api_options.recipe_implementation.create_new_code_for_device( device_id=device_id, - user_input_code=user_input_code, + user_input_code=user_input_code_input, tenant_id=tenant_id, user_context=user_context, ) ) + if isinstance( response, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError ): @@ -177,7 +413,30 @@ async def resend_code_post( if isinstance(response, CreateNewCodeForDeviceOkResult): magic_link = None user_input_code = None + + factor_ids = [] + if session is not None: + factor_ids = [ + ( + FactorIds.OTP_EMAIL + if device_info.email is not None + else FactorIds.OTP_PHONE + ) + ] + else: + factor_ids = get_enabled_pwless_factors(api_options.config) + factor_ids = await filter_out_invalid_first_factors_or_throw_if_all_are_invalid( + factor_ids, tenant_id, False, user_context + ) + flow_type = api_options.config.flow_type + if all(id.startswith("link") for id in factor_ids): + flow_type = "MAGIC_LINK" + elif all(id.startswith("otp") for id in factor_ids): + flow_type = "USER_INPUT_CODE" + else: + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + if flow_type in ("MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"): magic_link = ( api_options.app_info.get_origin( @@ -195,20 +454,31 @@ async def resend_code_post( if flow_type in ("USER_INPUT_CODE", "USER_INPUT_CODE_AND_MAGIC_LINK"): user_input_code = response.user_input_code - if isinstance( - api_options.config.contact_config, ContactEmailOnlyConfig - ) or ( - isinstance( - api_options.config.contact_config, ContactEmailOrPhoneConfig - ) - and device_info.email is not None + if api_options.config.contact_config.contact_method == "PHONE" or ( + api_options.config.contact_config.contact_method == "EMAIL_OR_PHONE" + and device_info.phone_number is not None ): - if device_info.email is None: - raise Exception("Should never come here") - + log_debug_message( + "Sending passwordless login SMS to %s", device_info.phone_number + ) + assert device_info.phone_number is not None + sms_input = PasswordlessLoginSMSTemplateVars( + phone_number=device_info.phone_number, + user_input_code=user_input_code, + url_with_link_code=magic_link, + code_life_time=response.code_life_time, + pre_auth_session_id=response.pre_auth_session_id, + tenant_id=tenant_id, + is_first_factor=auth_type_info.is_first_factor, + ) + await api_options.sms_delivery.ingredient_interface_impl.send_sms( + sms_input, user_context + ) + else: log_debug_message( "Sending passwordless login email to %s", device_info.email ) + assert device_info.email is not None passwordless_email_delivery_input = ( PasswordlessLoginEmailTemplateVars( email=device_info.email, @@ -217,32 +487,15 @@ async def resend_code_post( code_life_time=response.code_life_time, pre_auth_session_id=response.pre_auth_session_id, tenant_id=tenant_id, + is_first_factor=auth_type_info.is_first_factor, ) ) await api_options.email_delivery.ingredient_interface_impl.send_email( passwordless_email_delivery_input, user_context ) - elif isinstance( - api_options.config.contact_config, - (ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig), - ): - if device_info.phone_number is None: - raise Exception("Should never come here") - log_debug_message( - "Sending passwordless login SMS to %s", device_info.phone_number - ) - sms_input = PasswordlessLoginSMSTemplateVars( - phone_number=device_info.phone_number, - user_input_code=user_input_code, - url_with_link_code=magic_link, - code_life_time=response.code_life_time, - pre_auth_session_id=response.pre_auth_session_id, - tenant_id=tenant_id, - ) - await api_options.sms_delivery.ingredient_interface_impl.send_sms( - sms_input, user_context - ) + return ResendCodePostOkResult() + return ResendCodePostRestartFlowError() async def consume_code_post( @@ -251,65 +504,231 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ ConsumeCodePostOkResult, ConsumeCodePostRestartFlowError, + GeneralErrorResponse, ConsumeCodePostIncorrectUserInputCodeError, ConsumeCodePostExpiredUserInputCodeError, - GeneralErrorResponse, + SignInUpPostNotAllowedResponse, ]: + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_002)", + "SIGN_IN_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_003)", + "LINKING_TO_SESSION_USER_FAILED": { + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_017)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_018)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_019)", + }, + } + + device_info = ( + await api_options.recipe_implementation.list_codes_by_pre_auth_session_id( + tenant_id=tenant_id, + pre_auth_session_id=pre_auth_session_id, + user_context=user_context, + ) + ) + + if not device_info: + return ConsumeCodePostRestartFlowError() + + recipe_id = "passwordless" + account_info = AccountInfo( + phone_number=device_info.phone_number, email=device_info.email + ) + + async def check_credentials(_: str): + nonlocal check_credentials_response + if check_credentials_response is None: + check_credentials_response = ( + await api_options.recipe_implementation.check_code( + pre_auth_session_id=pre_auth_session_id, + device_id=device_id, + user_input_code=user_input_code, + link_code=link_code, + tenant_id=tenant_id, + user_context=user_context, + ) + ) + return isinstance(check_credentials_response, CheckCodeOkResult) + + check_credentials_response = None + authenticating_user = ( + await get_authenticating_user_and_add_to_current_tenant_if_required( + email=account_info.email, + phone_number=account_info.phone_number, + third_party=None, + recipe_id=recipe_id, + user_context=user_context, + session=session, + tenant_id=tenant_id, + check_credentials_on_tenant=check_credentials, + ) + ) + + ev_instance = EmailVerificationRecipe.get_instance_optional() + if account_info.email and session and ev_instance: + session_user = await get_user(session.get_user_id(), user_context) + if session_user is None: + raise UnauthorisedError( + "Session user not found", + ) + + login_method = next( + ( + lm + for lm in session_user.login_methods + if lm.recipe_user_id.get_as_string() + == session.get_recipe_user_id().get_as_string() + ), + None, + ) + if login_method is None: + raise UnauthorisedError( + "Session user and session recipeUserId is inconsistent", + ) + + if ( + login_method.has_same_email_as(account_info.email) + and not login_method.verified + ): + if await check_credentials(tenant_id): + token_response = await ev_instance.recipe_implementation.create_email_verification_token( + tenant_id=tenant_id, + recipe_user_id=login_method.recipe_user_id, + email=account_info.email, + user_context=user_context, + ) + if isinstance(token_response, CreateEmailVerificationTokenOkResult): + await ev_instance.recipe_implementation.verify_email_using_token( + tenant_id=tenant_id, + token=token_response.token, + attempt_account_linking=False, + user_context=user_context, + ) + + factor_id = ( + FactorIds.OTP_EMAIL + if device_info.email and user_input_code + else ( + FactorIds.LINK_EMAIL + if device_info.email + else (FactorIds.OTP_PHONE if user_input_code else FactorIds.LINK_PHONE) + ) + ) + + is_sign_up = authenticating_user is None + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id="passwordless", + email=device_info.email, + phone_number=device_info.phone_number, + ), + factor_ids=[factor_id], + authenticating_user=( + authenticating_user.user if authenticating_user else None + ), + is_sign_up=is_sign_up, + is_verified=( + authenticating_user.login_method.verified + if authenticating_user and authenticating_user.login_method + else True + ), + sign_in_verifies_login_method=True, + skip_session_user_update_in_core=False, + tenant_id=tenant_id, + user_context=user_context, + session=session, + ) + + if not isinstance(pre_auth_checks_result, OkResponse): + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInUpPostNotAllowedResponse(reason=reason) + + if check_credentials_response is not None: + if not isinstance(check_credentials_response, CheckCodeOkResult): + return check_credentials_response + response = await api_options.recipe_implementation.consume_code( pre_auth_session_id=pre_auth_session_id, - user_input_code=user_input_code, device_id=device_id, + user_input_code=user_input_code, link_code=link_code, + session=session, tenant_id=tenant_id, user_context=user_context, ) - if isinstance(response, ConsumeCodeExpiredUserInputCodeError): - return ConsumeCodePostExpiredUserInputCodeError( - failed_code_input_attempt_count=response.failed_code_input_attempt_count, - maximum_code_input_attempts=response.maximum_code_input_attempts, - ) + if isinstance(response, ConsumeCodeRestartFlowError): + return ConsumeCodePostRestartFlowError() if isinstance(response, ConsumeCodeIncorrectUserInputCodeError): return ConsumeCodePostIncorrectUserInputCodeError( - failed_code_input_attempt_count=response.failed_code_input_attempt_count, - maximum_code_input_attempts=response.maximum_code_input_attempts, + response.failed_code_input_attempt_count, + response.maximum_code_input_attempts, ) - if isinstance(response, ConsumeCodeRestartFlowError): - return ConsumeCodePostRestartFlowError() - - user = response.user - - if user.email is not None: - ev_instance = EmailVerificationRecipe.get_instance_optional() - if ev_instance is not None: - token_response = await ev_instance.recipe_implementation.create_email_verification_token( - RecipeUserId(user.user_id), user.email, tenant_id, user_context - ) + if isinstance(response, ConsumeCodeExpiredUserInputCodeError): + return ConsumeCodePostExpiredUserInputCodeError( + response.failed_code_input_attempt_count, + response.maximum_code_input_attempts, + ) + if not isinstance(response, ConsumeCodeOkResult): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[response.reason] + return SignInUpPostNotAllowedResponse(reason=reason) - if isinstance(token_response, CreateEmailVerificationTokenOkResult): - await ev_instance.recipe_implementation.verify_email_using_token( - token_response.token, tenant_id, True, user_context - ) + authenticating_user_input: AccountLinkingUser + if response.user: + authenticating_user_input = response.user + elif authenticating_user: + authenticating_user_input = authenticating_user.user + else: + raise Exception("Should never come here") + recipe_user_id_input: RecipeUserId + if response.recipe_user_id: + recipe_user_id_input = response.recipe_user_id + elif authenticating_user: + assert authenticating_user.login_method is not None + recipe_user_id_input = authenticating_user.login_method.recipe_user_id + else: + raise Exception("Should never come here") - session = await create_new_session( - request=api_options.request, + post_auth_checks_result = await post_auth_checks( + factor_id=factor_id, + is_sign_up=is_sign_up, + authenticated_user=authenticating_user_input, + recipe_user_id=recipe_user_id_input, tenant_id=tenant_id, - recipe_user_id=RecipeUserId(user.user_id), - access_token_payload={}, - session_data_in_database={}, user_context=user_context, + session=session, + request=api_options.request, ) + if not isinstance(post_auth_checks_result, PostAuthChecksOkResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + return ConsumeCodePostOkResult( - created_new_user=response.created_new_user, - user=response.user, - session=session, + created_new_recipe_user=response.created_new_recipe_user, + user=post_auth_checks_result.user, + session=post_auth_checks_result.session, ) async def email_exists_get( @@ -319,10 +738,21 @@ async def email_exists_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[EmailExistsGetOkResult, GeneralErrorResponse]: - response = await api_options.recipe_implementation.get_user_by_email( - email, tenant_id, user_context + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=email), + do_union_of_account_info=False, + user_context=user_context, ) - return EmailExistsGetOkResult(exists=response is not None) + user_exists = any( + any( + lm.recipe_id == "passwordless" and lm.has_same_email_as(email) + for lm in u.login_methods + ) + for u in users + ) + + return EmailExistsGetOkResult(exists=user_exists) async def phone_number_exists_get( self, @@ -331,7 +761,10 @@ async def phone_number_exists_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[PhoneNumberExistsGetOkResult, GeneralErrorResponse]: - response = await api_options.recipe_implementation.get_user_by_phone_number( - phone_number, tenant_id, user_context + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(phone_number=phone_number), + do_union_of_account_info=False, + user_context=user_context, ) - return PhoneNumberExistsGetOkResult(exists=response is not None) + return PhoneNumberExistsGetOkResult(exists=len(users) > 0) diff --git a/supertokens_python/recipe/passwordless/api/resend_code.py b/supertokens_python/recipe/passwordless/api/resend_code.py index 8afd885db..619d8e687 100644 --- a/supertokens_python/recipe/passwordless/api/resend_code.py +++ b/supertokens_python/recipe/passwordless/api/resend_code.py @@ -14,6 +14,7 @@ from typing import Any, Dict from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.utils import send_200_response @@ -39,7 +40,16 @@ async def resend_code( pre_auth_session_id = body["preAuthSessionId"] device_id = body["deviceId"] + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, pre_auth_session_id, session, tenant_id, api_options, user_context ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/passwordless/asyncio/__init__.py b/supertokens_python/recipe/passwordless/asyncio/__init__.py index 4632f6109..8b0a0bdfb 100644 --- a/supertokens_python/recipe/passwordless/asyncio/__init__.py +++ b/supertokens_python/recipe/passwordless/asyncio/__init__.py @@ -11,10 +11,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from supertokens_python import get_request_from_user_context +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.passwordless.interfaces import ( + CheckCodeExpiredUserInputCodeError, + CheckCodeIncorrectUserInputCodeError, + CheckCodeOkResult, + CheckCodeRestartFlowError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeOkResult, @@ -23,8 +28,8 @@ CreateNewCodeForDeviceOkResult, CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, - DeleteUserInfoOkResult, - DeleteUserInfoUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, RevokeAllCodesOkResult, RevokeCodeOkResult, UpdateUserEmailAlreadyExistsError, @@ -37,8 +42,9 @@ DeviceType, EmailTemplateVars, SMSTemplateVars, - User, ) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import RecipeUserId async def create_code( @@ -46,6 +52,7 @@ async def create_code( email: Union[None, str] = None, phone_number: Union[None, str] = None, user_input_code: Union[None, str] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> CreateCodeOkResult: if user_context is None: @@ -55,6 +62,7 @@ async def create_code( phone_number=phone_number, user_input_code=user_input_code, tenant_id=tenant_id, + session=session, user_context=user_context, ) @@ -85,12 +93,14 @@ async def consume_code( user_input_code: Union[str, None] = None, device_id: Union[str, None] = None, link_code: Union[str, None] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ ConsumeCodeOkResult, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, ]: if user_context is None: user_context = {} @@ -100,48 +110,13 @@ async def consume_code( device_id=device_id, link_code=link_code, tenant_id=tenant_id, - user_context=user_context, - ) - - -async def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return await PasswordlessRecipe.get_instance().recipe_implementation.get_user_by_id( - user_id=user_id, user_context=user_context - ) - - -async def get_user_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return ( - await PasswordlessRecipe.get_instance().recipe_implementation.get_user_by_email( - email=email, - tenant_id=tenant_id, - user_context=user_context, - ) - ) - - -async def get_user_by_phone_number( - tenant_id: str, phone_number: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return await PasswordlessRecipe.get_instance().recipe_implementation.get_user_by_phone_number( - phone_number=phone_number, - tenant_id=tenant_id, + session=session, user_context=user_context, ) async def update_user( - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None] = None, phone_number: Union[str, None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -150,34 +125,41 @@ async def update_user( UpdateUserUnknownUserIdError, UpdateUserEmailAlreadyExistsError, UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ]: if user_context is None: user_context = {} return await PasswordlessRecipe.get_instance().recipe_implementation.update_user( - user_id=user_id, + recipe_user_id=recipe_user_id, email=email, phone_number=phone_number, user_context=user_context, ) -async def delete_email_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - if user_context is None: - user_context = {} - return await PasswordlessRecipe.get_instance().recipe_implementation.delete_email_for_user( - user_id=user_id, user_context=user_context - ) - - -async def delete_phone_number_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: +async def check_code( + tenant_id: str, + pre_auth_session_id: str, + user_input_code: Union[str, None] = None, + device_id: Union[str, None] = None, + link_code: Union[str, None] = None, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, +]: if user_context is None: user_context = {} - return await PasswordlessRecipe.get_instance().recipe_implementation.delete_phone_number_for_user( - user_id=user_id, user_context=user_context + return await PasswordlessRecipe.get_instance().recipe_implementation.check_code( + pre_auth_session_id=pre_auth_session_id, + user_input_code=user_input_code, + device_id=device_id, + link_code=link_code, + tenant_id=tenant_id, + user_context=user_context, ) @@ -281,6 +263,7 @@ async def signinup( tenant_id: str, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], user_context: Union[None, Dict[str, Any]] = None, ) -> ConsumeCodeOkResult: if user_context is None: @@ -290,6 +273,7 @@ async def signinup( phone_number=phone_number, tenant_id=tenant_id, user_context=user_context, + session=session, ) diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 838890d7e..1470f657c 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -14,14 +14,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from typing_extensions import Literal +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.session import SessionContainer -from supertokens_python.types import APIResponse, GeneralErrorResponse +from supertokens_python.types import ( + APIResponse, + AccountLinkingUser, + GeneralErrorResponse, + RecipeUserId, +) from ...supertokens import AppInfo @@ -31,7 +37,6 @@ PasswordlessLoginEmailTemplateVars, PasswordlessLoginSMSTemplateVars, SMSDeliveryIngredient, - User, ) from .utils import PasswordlessConfig @@ -84,10 +89,41 @@ class CreateNewCodeForDeviceUserInputCodeAlreadyUsedError: pass +class ConsumedDevice: + def __init__( + self, + pre_auth_session_id: str, + failed_code_input_attempt_count: int, + email: Optional[str] = None, + phone_number: Optional[str] = None, + ): + self.pre_auth_session_id = pre_auth_session_id + self.failed_code_input_attempt_count = failed_code_input_attempt_count + self.email = email + self.phone_number = phone_number + + @staticmethod + def from_json(json: Dict[str, Any]) -> ConsumedDevice: + return ConsumedDevice( + pre_auth_session_id=json["preAuthSessionId"], + failed_code_input_attempt_count=json["failedCodeInputAttemptCount"], + email=json["email"], + phone_number=json["phoneNumber"], + ) + + class ConsumeCodeOkResult: - def __init__(self, created_new_user: bool, user: User): - self.created_new_user = created_new_user + def __init__( + self, + created_new_recipe_user: bool, + user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + consumed_device: ConsumedDevice, + ): + self.created_new_recipe_user = created_new_recipe_user self.user = user + self.recipe_user_id = recipe_user_id + self.consumed_device = consumed_device class ConsumeCodeIncorrectUserInputCodeError: @@ -126,22 +162,42 @@ class UpdateUserPhoneNumberAlreadyExistsError: pass -class DeleteUserInfoOkResult: +class RevokeAllCodesOkResult: pass -class DeleteUserInfoUnknownUserIdError: +class RevokeCodeOkResult: pass -class RevokeAllCodesOkResult: +class CheckCodeOkResult: + def __init__(self, consumed_device: ConsumedDevice): + self.status = "OK" + self.consumed_device = consumed_device + + +class CheckCodeIncorrectUserInputCodeError(ConsumeCodeIncorrectUserInputCodeError): pass -class RevokeCodeOkResult: +class CheckCodeExpiredUserInputCodeError(ConsumeCodeExpiredUserInputCodeError): + pass + + +class CheckCodeRestartFlowError(ConsumeCodeRestartFlowError): pass +class EmailChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason + + +class PhoneNumberChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason + + class RecipeInterface(ABC): def __init__(self): pass @@ -152,6 +208,7 @@ async def create_code( email: Union[None, str], phone_number: Union[None, str], user_input_code: Union[None, str], + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], ) -> CreateCodeOkResult: @@ -178,6 +235,7 @@ async def consume_code( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], ) -> Union[ @@ -185,31 +243,31 @@ async def consume_code( ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, ]: pass @abstractmethod - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - pass - - @abstractmethod - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - pass - - @abstractmethod - async def get_user_by_phone_number( - self, phone_number: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: + async def check_code( + self, + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, + ]: pass @abstractmethod async def update_user( self, - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None], phone_number: Union[str, None], user_context: Dict[str, Any], @@ -218,21 +276,11 @@ async def update_user( UpdateUserUnknownUserIdError, UpdateUserEmailAlreadyExistsError, UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ]: pass - @abstractmethod - async def delete_email_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - pass - - @abstractmethod - async def delete_phone_number_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - pass - @abstractmethod async def revoke_all_codes( self, @@ -337,21 +385,21 @@ def to_json(self): class ConsumeCodePostOkResult(APIResponse): status: str = "OK" - def __init__(self, created_new_user: bool, user: User, session: SessionContainer): - self.created_new_user = created_new_user + def __init__( + self, + created_new_recipe_user: bool, + user: AccountLinkingUser, + session: SessionContainer, + ): + self.created_new_recipe_user = created_new_recipe_user self.user = user self.session = session def to_json(self): - user = {"id": self.user.user_id, "time_joined": self.user.time_joined} - if self.user.email is not None: - user = {**user, "email": self.user.email} - if self.user.phone_number is not None: - user = {**user, "phoneNumber": self.user.phone_number} return { "status": self.status, - "createdNewUser": self.created_new_user, - "user": user, + "user": self.user.to_json(), + "createdNewRecipeUser": self.created_new_recipe_user, } @@ -416,6 +464,16 @@ def to_json(self): return {"status": self.status, "exists": self.exists} +class SignInUpPostNotAllowedResponse(APIResponse): + status: str = "SIGN_IN_UP_NOT_ALLOWED" + + def __init__(self, reason: str): + self.reason = reason + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "reason": self.reason} + + class APIInterface: def __init__(self): self.disable_create_code_post = False @@ -429,10 +487,13 @@ async def create_code_post( self, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[CreateCodePostOkResult, GeneralErrorResponse]: + ) -> Union[ + CreateCodePostOkResult, SignInUpPostNotAllowedResponse, GeneralErrorResponse + ]: pass @abstractmethod @@ -440,6 +501,7 @@ async def resend_code_post( self, device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -455,6 +517,7 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -464,6 +527,7 @@ async def consume_code_post( GeneralErrorResponse, ConsumeCodePostIncorrectUserInputCodeError, ConsumeCodePostExpiredUserInputCodeError, + SignInUpPostNotAllowedResponse, ]: pass diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index 93b7ef3dd..1d59d3483 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -15,17 +15,35 @@ from os import environ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union, Optional +from supertokens_python.auth_utils import is_fake_email from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery import SMSDeliveryIngredient from supertokens_python.querier import Querier +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import ( + FactorIds, + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetEmailsForFactorFromOtherRecipesFunc, + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + GetFactorsSetupForUserFromOtherRecipesFunc, + GetPhoneNumbersForFactorsFromOtherRecipesFunc, + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.passwordless.types import ( PasswordlessIngredients, PasswordlessLoginSMSTemplateVars, ) from typing_extensions import Literal +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import AccountLinkingUser, RecipeUserId + from .api import ( consume_code, create_code, @@ -56,11 +74,6 @@ OverrideConfig, validate_and_normalise_user_input, ) -from ..emailverification.interfaces import ( - GetEmailForUserIdOkResult, - EmailDoesNotExistError, - UnknownUserIdError, -) from ...post_init_callbacks import PostSTInitCallbacks if TYPE_CHECKING: @@ -141,7 +154,251 @@ def __init__( ) def callback(): - pass + mfa_instance = MultiFactorAuthRecipe.get_instance() + all_factors = [ + FactorIds.OTP_EMAIL, + FactorIds.LINK_EMAIL, + FactorIds.OTP_PHONE, + FactorIds.LINK_PHONE, + ] + if mfa_instance is not None: + + async def f1(_: TenantConfig): + return all_factors + + mfa_instance.add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc(f1) + ) + + async def get_factors_setup_for_user( + user: AccountLinkingUser, _: Dict[str, Any] + ) -> List[str]: + def is_factor_setup_for_user( + user: AccountLinkingUser, factor_id: str + ) -> bool: + for login_method in user.login_methods: + if login_method.recipe_id != "passwordless": + continue + + if login_method.email is not None and not is_fake_email( + login_method.email + ): + if factor_id in [ + FactorIds.OTP_EMAIL, + FactorIds.LINK_EMAIL, + ]: + return True + + if login_method.phone_number is not None: + if factor_id in [ + FactorIds.OTP_PHONE, + FactorIds.LINK_PHONE, + ]: + return True + return False + + return [ + factor_id + for factor_id in all_factors + if is_factor_setup_for_user(user, factor_id) + ] + + mfa_instance.add_func_to_get_factors_setup_for_user_from_other_recipes( + GetFactorsSetupForUserFromOtherRecipesFunc( + get_factors_setup_for_user + ) + ) + + async def get_emails_for_factor( + user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ]: + session_login_method = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ), + None, + ) + if session_login_method is None: + return GetEmailsForFactorUnknownSessionRecipeUserIdResult() + + ordered_login_methods = sorted( + user.login_methods, key=lambda lm: lm.time_joined + ) + + # MAIN LOGIC FOR THE FUNCTION STARTS HERE + non_fake_emails_passwordless = [ + lm.email + for lm in ordered_login_methods + if lm.recipe_id == "passwordless" + and lm.email is not None + and not is_fake_email(lm.email) + ] + + if not non_fake_emails_passwordless: + # This factor is not set up for email-based factors. + # We check for emails from other loginMethods and return those. + emails_result = [] + if ( + session_login_method.email is not None + and not is_fake_email(session_login_method.email) + ): + emails_result = [session_login_method.email] + + emails_result.extend( + [ + lm.email + for lm in ordered_login_methods + if lm.email is not None + and not is_fake_email(lm.email) + and lm.email not in emails_result + ] + ) + factor_id_to_emails_map = {} + if FactorIds.OTP_EMAIL in all_factors: + factor_id_to_emails_map[FactorIds.OTP_EMAIL] = emails_result + if FactorIds.LINK_EMAIL in all_factors: + factor_id_to_emails_map[ + FactorIds.LINK_EMAIL + ] = emails_result + return GetEmailsForFactorOkResult( + factor_id_to_emails_map=factor_id_to_emails_map + ) + elif len(non_fake_emails_passwordless) == 1: + # Return just this email to avoid creating more loginMethods + factor_id_to_emails_map = {} + if FactorIds.OTP_EMAIL in all_factors: + factor_id_to_emails_map[ + FactorIds.OTP_EMAIL + ] = non_fake_emails_passwordless + if FactorIds.LINK_EMAIL in all_factors: + factor_id_to_emails_map[ + FactorIds.LINK_EMAIL + ] = non_fake_emails_passwordless + return GetEmailsForFactorOkResult( + factor_id_to_emails_map=factor_id_to_emails_map + ) + + # Return all emails with passwordless login method, prioritizing session's email + emails_result = [] + if ( + session_login_method.email is not None + and session_login_method.email in non_fake_emails_passwordless + ): + emails_result = [session_login_method.email] + + emails_result.extend( + [ + email + for email in non_fake_emails_passwordless + if email not in emails_result + ] + ) + + factor_id_to_emails_map = {} + if FactorIds.OTP_EMAIL in all_factors: + factor_id_to_emails_map[FactorIds.OTP_EMAIL] = emails_result + if FactorIds.LINK_EMAIL in all_factors: + factor_id_to_emails_map[FactorIds.LINK_EMAIL] = emails_result + + return GetEmailsForFactorOkResult( + factor_id_to_emails_map=factor_id_to_emails_map + ) + + mfa_instance.add_func_to_get_emails_for_factor_from_other_recipes( + GetEmailsForFactorFromOtherRecipesFunc(get_emails_for_factor) + ) + + async def get_phone_numbers_for_factors( + user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + ]: + session_login_method = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ), + None, + ) + if session_login_method is None: + return ( + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult() + ) + + ordered_login_methods = sorted( + user.login_methods, key=lambda lm: lm.time_joined + ) + + phone_numbers = [ + lm.phone_number + for lm in ordered_login_methods + if lm.recipe_id == "passwordless" + and lm.phone_number is not None + ] + + if not phone_numbers: + phones_result = [] + if session_login_method.phone_number is not None: + phones_result = [session_login_method.phone_number] + + phones_result.extend( + [ + lm.phone_number + for lm in ordered_login_methods + if lm.phone_number is not None + and lm.phone_number not in phones_result + ] + ) + elif len(phone_numbers) == 1: + phones_result = phone_numbers + else: + phones_result = [] + if ( + session_login_method.phone_number is not None + and session_login_method.phone_number in phone_numbers + ): + phones_result = [session_login_method.phone_number] + phones_result.extend( + [ + phone + for phone in phone_numbers + if phone not in phones_result + ] + ) + + factor_id_to_phone_number_map = {} + if FactorIds.OTP_PHONE in all_factors: + factor_id_to_phone_number_map[ + FactorIds.OTP_PHONE + ] = phones_result + if FactorIds.LINK_PHONE in all_factors: + factor_id_to_phone_number_map[ + FactorIds.LINK_PHONE + ] = phones_result + + return GetPhoneNumbersForFactorsOkResult( + factor_id_to_phone_number_map=factor_id_to_phone_number_map + ) + + mfa_instance.add_func_to_get_phone_numbers_for_factors_from_other_recipes( + GetPhoneNumbersForFactorsFromOtherRecipesFunc( + get_phone_numbers_for_factors + ) + ) + + mt_recipe = MultitenancyRecipe.get_instance_optional() + if mt_recipe is not None: + for factor_id in all_factors: + mt_recipe.all_available_first_factors.append(factor_id) PostSTInitCallbacks.add_post_init_callback(callback) @@ -327,6 +584,7 @@ async def create_magic_link( tenant_id=tenant_id, user_input_code=user_input_code, user_context=user_context, + session=None, ) app_info = self.get_app_info() @@ -348,6 +606,7 @@ async def signinup( self, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], ) -> ConsumeCodeOkResult: @@ -357,6 +616,7 @@ async def signinup( user_input_code=None, tenant_id=tenant_id, user_context=user_context, + session=session, ) consume_code_result = await self.recipe_implementation.consume_code( link_code=code_info.link_code, @@ -365,19 +625,8 @@ async def signinup( user_input_code=code_info.user_input_code, tenant_id=tenant_id, user_context=user_context, + session=session, ) if isinstance(consume_code_result, ConsumeCodeOkResult): return consume_code_result raise Exception("Failed to create user. Please retry") - - async def get_email_for_user_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError]: - user_info = await self.recipe_implementation.get_user_by_id( - user_id, user_context - ) - if user_info is not None: - if user_info.email is not None: - return GetEmailForUserIdOkResult(user_info.email) - return EmailDoesNotExistError() - return UnknownUserIdError() diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 19a75b97c..b5d030171 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -13,25 +13,32 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedError, + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, +) from supertokens_python.querier import Querier - -from .types import DeviceCode, DeviceType, User - from supertokens_python.normalised_url_path import NormalisedURLPath - -from .interfaces import ( +from supertokens_python.recipe.passwordless.types import DeviceCode, DeviceType +from supertokens_python.recipe.passwordless.interfaces import ( + CheckCodeExpiredUserInputCodeError, + CheckCodeIncorrectUserInputCodeError, + CheckCodeOkResult, + CheckCodeRestartFlowError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeOkResult, ConsumeCodeRestartFlowError, + ConsumedDevice, CreateCodeOkResult, CreateNewCodeForDeviceOkResult, CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, - DeleteUserInfoOkResult, - DeleteUserInfoUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, RecipeInterface, RevokeAllCodesOkResult, RevokeCodeOkResult, @@ -40,6 +47,11 @@ UpdateUserPhoneNumberAlreadyExistsError, UpdateUserUnknownUserIdError, ) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.utils import log_debug_message +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe class RecipeImplementation(RecipeInterface): @@ -47,34 +59,163 @@ def __init__(self, querier: Querier): super().__init__() self.querier = querier + async def consume_code( + self, + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + session: Optional[SessionContainer], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + ConsumeCodeOkResult, + ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeExpiredUserInputCodeError, + ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, + ]: + input_dict = { + "preAuthSessionId": pre_auth_session_id, + } + if link_code is not None: + input_dict["linkCode"] = link_code + else: + if user_input_code is None or device_id is None: + return ConsumeCodeRestartFlowError() + input_dict["userInputCode"] = user_input_code + input_dict["deviceId"] = device_id + + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/consume"), + input_dict, + user_context=user_context, + ) + + if response["status"] == "INCORRECT_USER_INPUT_CODE_ERROR": + return ConsumeCodeIncorrectUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "EXPIRED_USER_INPUT_CODE_ERROR": + return ConsumeCodeExpiredUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "RESTART_FLOW_ERROR": + return ConsumeCodeRestartFlowError() + + # status == "OK" + + log_debug_message("Passwordless.consumeCode code consumed OK") + + recipe_user_id = RecipeUserId(response["recipeUserId"]) + + updated_user = AccountLinkingUser.from_json(response["user"]) + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=updated_user, + recipe_user_id=recipe_user_id, + session=session, + user_context=user_context, + ) + + if isinstance(link_result, LinkingToSessionUserFailedError): + return link_result + + updated_user = link_result.user + + response["user"] = updated_user + + return ConsumeCodeOkResult( + user=updated_user, + recipe_user_id=recipe_user_id, + consumed_device=ConsumedDevice.from_json(response["consumedDevice"]), + created_new_recipe_user=response["createdNewUser"], + ) + + async def check_code( + self, + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, + ]: + input_dict = { + "preAuthSessionId": pre_auth_session_id, + } + if link_code is not None: + input_dict["linkCode"] = link_code + else: + if user_input_code is None or device_id is None: + return CheckCodeRestartFlowError() + input_dict["userInputCode"] = user_input_code + input_dict["deviceId"] = device_id + + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/check"), + input_dict, + user_context=user_context, + ) + + if response["status"] == "INCORRECT_USER_INPUT_CODE_ERROR": + return CheckCodeIncorrectUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "EXPIRED_USER_INPUT_CODE_ERROR": + return CheckCodeExpiredUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "RESTART_FLOW_ERROR": + return CheckCodeRestartFlowError() + + # status == "OK" + log_debug_message("Passwordless.checkCode code verified") + + return CheckCodeOkResult( + consumed_device=ConsumedDevice.from_json(response["consumedDevice"]) + ) + async def create_code( self, email: Union[None, str], phone_number: Union[None, str], user_input_code: Union[None, str], + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], ) -> CreateCodeOkResult: - data: Dict[str, Any] = {} - if user_input_code is not None: - data = {**data, "userInputCode": user_input_code} - if email is not None: - data = {**data, "email": email} - if phone_number is not None: - data = {**data, "phoneNumber": phone_number} - result = await self.querier.send_post_request( + input_dict = {} + if email: + input_dict["email"] = email + if phone_number: + input_dict["phoneNumber"] = phone_number + if user_input_code: + input_dict["userInputCode"] = user_input_code + + response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/signinup/code"), - data, + input_dict, user_context=user_context, ) return CreateCodeOkResult( - pre_auth_session_id=result["preAuthSessionId"], - code_id=result["codeId"], - device_id=result["deviceId"], - user_input_code=result["userInputCode"], - link_code=result["linkCode"], - time_created=result["timeCreated"], - code_life_time=result["codeLifetime"], + pre_auth_session_id=response["preAuthSessionId"], + code_id=response["codeId"], + device_id=response["deviceId"], + user_input_code=response["userInputCode"], + link_code=response["linkCode"], + code_life_time=response["codeLifeTime"], + time_created=response["timeCreated"], ) async def create_new_code_for_device( @@ -110,226 +251,43 @@ async def create_new_code_for_device( time_created=result["timeCreated"], ) - async def consume_code( - self, - pre_auth_session_id: str, - user_input_code: Union[str, None], - device_id: Union[str, None], - link_code: Union[str, None], - tenant_id: str, - user_context: Dict[str, Any], - ) -> Union[ - ConsumeCodeOkResult, - ConsumeCodeIncorrectUserInputCodeError, - ConsumeCodeExpiredUserInputCodeError, - ConsumeCodeRestartFlowError, - ]: - data = {"preAuthSessionId": pre_auth_session_id} - if device_id is not None: - data = {**data, "deviceId": device_id, "userInputCode": user_input_code} - else: - data = {**data, "linkCode": link_code} - result = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/consume"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - email = None - phone_number = None - if "email" in result["user"]: - email = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number = result["user"]["phoneNumber"] - user = User( - user_id=result["user"]["id"], - email=email, - phone_number=phone_number, - time_joined=result["user"]["timeJoined"], - tenant_ids=result["user"]["tenantIds"], - ) - return ConsumeCodeOkResult(result["createdNewUser"], user) - if result["status"] == "RESTART_FLOW_ERROR": - return ConsumeCodeRestartFlowError() - if result["status"] == "INCORRECT_USER_INPUT_CODE_ERROR": - return ConsumeCodeIncorrectUserInputCodeError( - failed_code_input_attempt_count=result["failedCodeInputAttemptCount"], - maximum_code_input_attempts=result["maximumCodeInputAttempts"], - ) - return ConsumeCodeExpiredUserInputCodeError( - failed_code_input_attempt_count=result["failedCodeInputAttemptCount"], - maximum_code_input_attempts=result["maximumCodeInputAttempts"], - ) - - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - param = {"userId": user_id} + async def list_codes_by_device_id( + self, device_id: str, tenant_id: str, user_context: Dict[str, Any] + ) -> Union[DeviceType, None]: + param = {"deviceId": device_id} result = await self.querier.send_get_request( - NormalisedURLPath("/recipe/user"), + NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes"), param, user_context=user_context, ) - if result["status"] == "OK": + if "devices" in result and len(result["devices"]) == 1: + codes: List[DeviceCode] = [] + if "code" in result["devices"][0]: + for code in result["devices"][0]: + codes.append( + DeviceCode( + code_id=code["codeId"], + time_created=code["timeCreated"], + code_life_time=code["codeLifetime"], + ) + ) email = None phone_number = None - if "email" in result["user"]: - email = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number = result["user"]["phoneNumber"] - return User( - user_id=result["user"]["id"], + if "email" in result["devices"][0]: + email = result["devices"][0]["email"] + if "phoneNumber" in result["devices"][0]: + phone_number = result["devices"][0]["phoneNumber"] + return DeviceType( + pre_auth_session_id=result["devices"][0]["preAuthSessionId"], + failed_code_input_attempt_count=result["devices"][0][ + "failedCodeInputAttemptCount" + ], + codes=codes, email=email, phone_number=phone_number, - time_joined=result["user"]["timeJoined"], - tenant_ids=result["user"]["tenantIds"], ) return None - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - param = {"email": email} - result = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), - param, - user_context=user_context, - ) - if result["status"] == "OK": - email_resp = None - phone_number_resp = None - if "email" in result["user"]: - email_resp = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number_resp = result["user"]["phoneNumber"] - return User( - user_id=result["user"]["id"], - email=email_resp, - phone_number=phone_number_resp, - tenant_ids=result["user"]["tenantIds"], - time_joined=result["user"]["timeJoined"], - ) - return None - - async def get_user_by_phone_number( - self, phone_number: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - param = {"phoneNumber": phone_number} - result = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), - param, - user_context=user_context, - ) - if result["status"] == "OK": - email_resp = None - phone_number_resp = None - if "email" in result["user"]: - email_resp = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number_resp = result["user"]["phoneNumber"] - return User( - user_id=result["user"]["id"], - email=email_resp, - phone_number=phone_number_resp, - time_joined=result["user"]["timeJoined"], - tenant_ids=result["user"]["tenantIds"], - ) - return None - - async def update_user( - self, - user_id: str, - email: Union[str, None], - phone_number: Union[str, None], - user_context: Dict[str, Any], - ) -> Union[ - UpdateUserOkResult, - UpdateUserUnknownUserIdError, - UpdateUserEmailAlreadyExistsError, - UpdateUserPhoneNumberAlreadyExistsError, - ]: - data = {"userId": user_id} - if email is not None: - data = {**data, "email": email} - if phone_number is not None: - data = {**data, "phoneNumber": phone_number} - result = await self.querier.send_put_request( - NormalisedURLPath("/recipe/user"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - return UpdateUserOkResult() - if result["status"] == "UNKNOWN_USER_ID_ERROR": - return UpdateUserUnknownUserIdError() - if result["status"] == "EMAIL_ALREADY_EXISTS_ERROR": - return UpdateUserEmailAlreadyExistsError() - return UpdateUserPhoneNumberAlreadyExistsError() - - async def delete_email_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - data = {"userId": user_id, "email": None} - result = await self.querier.send_put_request( - NormalisedURLPath("/recipe/user"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - return DeleteUserInfoOkResult() - if result.get("EMAIL_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - if result.get("PHONE_NUMBER_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - return DeleteUserInfoUnknownUserIdError() - - async def delete_phone_number_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - data = {"userId": user_id, "phoneNumber": None} - result = await self.querier.send_put_request( - NormalisedURLPath("/recipe/user"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - return DeleteUserInfoOkResult() - if result.get("EMAIL_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - if result.get("PHONE_NUMBER_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - return DeleteUserInfoUnknownUserIdError() - - async def revoke_all_codes( - self, - email: Union[str, None], - phone_number: Union[str, None], - tenant_id: str, - user_context: Dict[str, Any], - ) -> RevokeAllCodesOkResult: - data: Dict[str, Any] = {} - if email is not None: - data = {**data, "email": email} - if phone_number is not None: - data = {**data, "email": phone_number} - await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes/remove"), - data, - user_context=user_context, - ) - return RevokeAllCodesOkResult() - - async def revoke_code( - self, code_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> RevokeCodeOkResult: - data = {"codeId": code_id} - await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/remove"), - data, - user_context=user_context, - ) - return RevokeCodeOkResult() - async def list_codes_by_email( self, email: str, tenant_id: str, user_context: Dict[str, Any] ) -> List[DeviceType]: @@ -412,10 +370,10 @@ async def list_codes_by_phone_number( ) return devices - async def list_codes_by_device_id( - self, device_id: str, tenant_id: str, user_context: Dict[str, Any] + async def list_codes_by_pre_auth_session_id( + self, pre_auth_session_id: str, tenant_id: str, user_context: Dict[str, Any] ) -> Union[DeviceType, None]: - param = {"deviceId": device_id} + param = {"preAuthSessionId": pre_auth_session_id} result = await self.querier.send_get_request( NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes"), param, @@ -449,39 +407,119 @@ async def list_codes_by_device_id( ) return None - async def list_codes_by_pre_auth_session_id( - self, pre_auth_session_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[DeviceType, None]: - param = {"preAuthSessionId": pre_auth_session_id} - result = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes"), - param, + async def revoke_all_codes( + self, + email: Union[str, None], + phone_number: Union[str, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> RevokeAllCodesOkResult: + data: Dict[str, Any] = {} + if email is not None: + data = {**data, "email": email} + if phone_number is not None: + data = {**data, "email": phone_number} + await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes/remove"), + data, user_context=user_context, ) - if "devices" in result and len(result["devices"]) == 1: - codes: List[DeviceCode] = [] - if "code" in result["devices"][0]: - for code in result["devices"][0]: - codes.append( - DeviceCode( - code_id=code["codeId"], - time_created=code["timeCreated"], - code_life_time=code["codeLifetime"], - ) + return RevokeAllCodesOkResult() + + async def revoke_code( + self, code_id: str, tenant_id: str, user_context: Dict[str, Any] + ) -> RevokeCodeOkResult: + data = {"codeId": code_id} + await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/remove"), + data, + user_context=user_context, + ) + return RevokeCodeOkResult() + + async def update_user( + self, + recipe_user_id: RecipeUserId, + email: Union[str, None], + phone_number: Union[str, None], + user_context: Dict[str, Any], + ) -> Union[ + UpdateUserOkResult, + UpdateUserUnknownUserIdError, + UpdateUserEmailAlreadyExistsError, + UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, + ]: + account_linking = AccountLinkingRecipe.get_instance() + if email: + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UpdateUserUnknownUserIdError() + + ev_instance = EmailVerificationRecipe.get_instance_optional() + is_email_verified = False + if ev_instance: + is_email_verified = ( + await ev_instance.recipe_implementation.is_email_verified( + recipe_user_id=recipe_user_id, + email=email, + user_context=user_context, ) - email = None - phone_number = None - if "email" in result["devices"][0]: - email = result["devices"][0]["email"] - if "phoneNumber" in result["devices"][0]: - phone_number = result["devices"][0]["phoneNumber"] - return DeviceType( - pre_auth_session_id=result["devices"][0]["preAuthSessionId"], - failed_code_input_attempt_count=result["devices"][0][ - "failedCodeInputAttemptCount" - ], - codes=codes, - email=email, - phone_number=phone_number, + ) + + is_email_change_allowed = await account_linking.is_email_change_allowed( + user=user, + is_verified=is_email_verified, + new_email=email, + session=None, + user_context=user_context, ) - return None + if not is_email_change_allowed.allowed: + return EmailChangeNotAllowedError( + reason=( + "New email cannot be applied to existing account because of account takeover risks." + if is_email_change_allowed.reason == "ACCOUNT_TAKEOVER_RISK" + else "New email cannot be applied to existing account because there is another primary user with the same email address." + ), + ) + + input_dict = { + "recipeUserId": recipe_user_id.get_as_string(), + } + if email: + input_dict = {**input_dict, "email": email} + if phone_number: + input_dict = {**input_dict, "phoneNumber": phone_number} + + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/user"), + input_dict, + user_context=user_context, + ) + if response["status"] == "UNKNOWN_USER_ID_ERROR": + return UpdateUserUnknownUserIdError() + elif response["status"] == "EMAIL_ALREADY_EXISTS_ERROR": + return UpdateUserEmailAlreadyExistsError() + elif response["status"] == "PHONE_NUMBER_ALREADY_EXISTS_ERROR": + return UpdateUserPhoneNumberAlreadyExistsError() + elif response["status"] == "EMAIL_CHANGE_NOT_ALLOWED_ERROR": + return EmailChangeNotAllowedError( + reason=response["reason"], + ) + elif response["status"] == "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR": + return PhoneNumberChangeNotAllowedError( + reason=response["reason"], + ) + + # status is OK + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UpdateUserUnknownUserIdError() + + await account_linking.verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + return UpdateUserOkResult() diff --git a/supertokens_python/recipe/passwordless/syncio/__init__.py b/supertokens_python/recipe/passwordless/syncio/__init__.py index 7210f8153..5618a64d3 100644 --- a/supertokens_python/recipe/passwordless/syncio/__init__.py +++ b/supertokens_python/recipe/passwordless/syncio/__init__.py @@ -11,10 +11,11 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union - +from typing import Any, Dict, List, Optional, Union from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.passwordless import asyncio +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.passwordless.interfaces import ( ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, @@ -24,21 +25,25 @@ CreateNewCodeForDeviceOkResult, CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, - DeleteUserInfoOkResult, - DeleteUserInfoUnknownUserIdError, RevokeAllCodesOkResult, RevokeCodeOkResult, UpdateUserEmailAlreadyExistsError, UpdateUserOkResult, UpdateUserPhoneNumberAlreadyExistsError, UpdateUserUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, + CheckCodeExpiredUserInputCodeError, + CheckCodeIncorrectUserInputCodeError, + CheckCodeOkResult, + CheckCodeRestartFlowError, ) from supertokens_python.recipe.passwordless.types import ( DeviceType, - PasswordlessLoginEmailTemplateVars, - PasswordlessLoginSMSTemplateVars, - User, + EmailTemplateVars, + SMSTemplateVars, ) +from supertokens_python.types import RecipeUserId def create_code( @@ -46,14 +51,18 @@ def create_code( email: Union[None, str] = None, phone_number: Union[None, str] = None, user_input_code: Union[None, str] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> CreateCodeOkResult: + if user_context is None: + user_context = {} return sync( asyncio.create_code( tenant_id, email=email, phone_number=phone_number, user_input_code=user_input_code, + session=session, user_context=user_context, ) ) @@ -69,6 +78,8 @@ def create_new_code_for_device( CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, ]: + if user_context is None: + user_context = {} return sync( asyncio.create_new_code_for_device( tenant_id, @@ -85,13 +96,17 @@ def consume_code( user_input_code: Union[str, None] = None, device_id: Union[str, None] = None, link_code: Union[str, None] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ ConsumeCodeOkResult, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, ]: + if user_context is None: + user_context = {} return sync( asyncio.consume_code( tenant_id, @@ -99,39 +114,14 @@ def consume_code( user_input_code=user_input_code, device_id=device_id, link_code=link_code, + session=session, user_context=user_context, ) ) -def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - return sync(asyncio.get_user_by_id(user_id=user_id, user_context=user_context)) - - -def get_user_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - return sync( - asyncio.get_user_by_email(tenant_id, email=email, user_context=user_context) - ) - - -def get_user_by_phone_number( - tenant_id: str, - phone_number: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> Union[User, None]: - return sync( - asyncio.get_user_by_phone_number( - tenant_id=tenant_id, phone_number=phone_number, user_context=user_context - ) - ) - - def update_user( - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None] = None, phone_number: Union[str, None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -140,10 +130,14 @@ def update_user( UpdateUserUnknownUserIdError, UpdateUserEmailAlreadyExistsError, UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ]: + if user_context is None: + user_context = {} return sync( asyncio.update_user( - user_id=user_id, + recipe_user_id=recipe_user_id, email=email, phone_number=phone_number, user_context=user_context, @@ -151,28 +145,14 @@ def update_user( ) -def delete_email_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - return sync( - asyncio.delete_email_for_user(user_id=user_id, user_context=user_context) - ) - - -def delete_phone_number_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - return sync( - asyncio.delete_phone_number_for_user(user_id=user_id, user_context=user_context) - ) - - def revoke_all_codes( tenant_id: str, email: Union[str, None] = None, phone_number: Union[str, None] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> RevokeAllCodesOkResult: + if user_context is None: + user_context = {} return sync( asyncio.revoke_all_codes( tenant_id, email=email, phone_number=phone_number, user_context=user_context @@ -183,6 +163,8 @@ def revoke_all_codes( def revoke_code( tenant_id: str, code_id: str, user_context: Union[None, Dict[str, Any]] = None ) -> RevokeCodeOkResult: + if user_context is None: + user_context = {} return sync( asyncio.revoke_code(tenant_id, code_id=code_id, user_context=user_context) ) @@ -191,6 +173,8 @@ def revoke_code( def list_codes_by_email( tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None ) -> List[DeviceType]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_email(tenant_id, email=email, user_context=user_context) ) @@ -199,6 +183,8 @@ def list_codes_by_email( def list_codes_by_phone_number( tenant_id: str, phone_number: str, user_context: Union[None, Dict[str, Any]] = None ) -> List[DeviceType]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_phone_number( tenant_id, phone_number=phone_number, user_context=user_context @@ -211,6 +197,8 @@ def list_codes_by_device_id( device_id: str, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[DeviceType, None]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_device_id( tenant_id=tenant_id, device_id=device_id, user_context=user_context @@ -223,6 +211,8 @@ def list_codes_by_pre_auth_session_id( pre_auth_session_id: str, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[DeviceType, None]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_pre_auth_session_id( tenant_id=tenant_id, @@ -238,6 +228,8 @@ def create_magic_link( phone_number: Union[str, None], user_context: Union[None, Dict[str, Any]] = None, ) -> str: + if user_context is None: + user_context = {} return sync( asyncio.create_magic_link( tenant_id=tenant_id, @@ -252,27 +244,62 @@ def signinup( tenant_id: str, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], user_context: Union[None, Dict[str, Any]] = None, ) -> ConsumeCodeOkResult: + if user_context is None: + user_context = {} return sync( asyncio.signinup( tenant_id=tenant_id, email=email, phone_number=phone_number, user_context=user_context, + session=session, ) ) def send_email( - input_: PasswordlessLoginEmailTemplateVars, + input_: EmailTemplateVars, user_context: Union[None, Dict[str, Any]] = None, -) -> None: +): + if user_context is None: + user_context = {} return sync(asyncio.send_email(input_, user_context)) def send_sms( - input_: PasswordlessLoginSMSTemplateVars, + input_: SMSTemplateVars, user_context: Union[None, Dict[str, Any]] = None, -) -> None: +): + if user_context is None: + user_context = {} return sync(asyncio.send_sms(input_, user_context)) + + +def check_code( + tenant_id: str, + pre_auth_session_id: str, + user_input_code: Union[str, None] = None, + device_id: Union[str, None] = None, + link_code: Union[str, None] = None, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, +]: + if user_context is None: + user_context = {} + return sync( + asyncio.check_code( + tenant_id, + pre_auth_session_id=pre_auth_session_id, + user_input_code=user_input_code, + device_id=device_id, + link_code=link_code, + user_context=user_context, + ) + ) diff --git a/supertokens_python/recipe/passwordless/types.py b/supertokens_python/recipe/passwordless/types.py index 38f7bcb82..03bc09cd6 100644 --- a/supertokens_python/recipe/passwordless/types.py +++ b/supertokens_python/recipe/passwordless/types.py @@ -25,32 +25,6 @@ ) -class User: - def __init__( - self, - user_id: str, - email: Union[str, None], - phone_number: Union[str, None], - time_joined: int, - tenant_ids: List[str], - ): - self.user_id = user_id - self.email = email - self.phone_number = phone_number - self.time_joined = time_joined - self.tenant_ids = tenant_ids - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.user_id == other.user_id - and self.email == other.email - and self.phone_number == other.phone_number - and self.time_joined == other.time_joined - and self.tenant_ids == other.tenant_ids - ) - - class DeviceCode: def __init__(self, code_id: str, time_created: str, code_life_time: int): self.code_id = code_id @@ -81,6 +55,7 @@ def __init__( code_life_time: int, pre_auth_session_id: str, email: str, + is_first_factor: bool, user_input_code: Union[str, None] = None, url_with_link_code: Union[str, None] = None, ): @@ -90,6 +65,7 @@ def __init__( self.user_input_code = user_input_code self.url_with_link_code = url_with_link_code self.tenant_id = tenant_id + self.is_first_factor = is_first_factor PasswordlessLoginEmailTemplateVars = CreateAndSendCustomEmailParameters @@ -102,6 +78,7 @@ def __init__( code_life_time: int, pre_auth_session_id: str, phone_number: str, + is_first_factor: bool, user_input_code: Union[str, None] = None, url_with_link_code: Union[str, None] = None, ): @@ -111,6 +88,7 @@ def __init__( self.user_input_code = user_input_code self.url_with_link_code = url_with_link_code self.tenant_id = tenant_id + self.is_first_factor = is_first_factor PasswordlessLoginSMSTemplateVars = CreateAndSendCustomTextMessageParameters diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index 9fb619c53..dcd85b39c 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union, List from supertokens_python.ingredients.emaildelivery.types import ( EmailDeliveryConfig, @@ -25,6 +25,7 @@ SMSDeliveryConfig, SMSDeliveryConfigWithService, ) +from supertokens_python.recipe.multifactorauth.types import FactorIds from supertokens_python.recipe.passwordless.types import ( PasswordlessLoginSMSTemplateVars, ) @@ -187,9 +188,9 @@ def validate_and_normalise_user_input( if override is None: override = OverrideConfig() - def get_email_delivery_config() -> EmailDeliveryConfigWithService[ - PasswordlessLoginEmailTemplateVars - ]: + def get_email_delivery_config() -> ( + EmailDeliveryConfigWithService[PasswordlessLoginEmailTemplateVars] + ): email_service = email_delivery.service if email_delivery is not None else None if email_service is None: @@ -202,9 +203,9 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ return EmailDeliveryConfigWithService(email_service, override=override) - def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ - PasswordlessLoginSMSTemplateVars - ]: + def get_sms_delivery_config() -> ( + SMSDeliveryConfigWithService[PasswordlessLoginSMSTemplateVars] + ): sms_service = sms_delivery.service if sms_delivery is not None else None if sms_service is None: @@ -240,3 +241,38 @@ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ get_sms_delivery_config=get_sms_delivery_config, get_custom_user_input_code=get_custom_user_input_code, ) + + +def get_enabled_pwless_factors( + config: PasswordlessConfig, +) -> List[str]: + all_factors: List[str] = [] + + if config.flow_type == "MAGIC_LINK": + if config.contact_config.contact_method == "EMAIL": + all_factors = [FactorIds.LINK_EMAIL] + elif config.contact_config.contact_method == "PHONE": + all_factors = [FactorIds.LINK_PHONE] + else: + all_factors = [FactorIds.LINK_EMAIL, FactorIds.LINK_PHONE] + elif config.flow_type == "USER_INPUT_CODE": + if config.contact_config.contact_method == "EMAIL": + all_factors = [FactorIds.OTP_EMAIL] + elif config.contact_config.contact_method == "PHONE": + all_factors = [FactorIds.OTP_PHONE] + else: + all_factors = [FactorIds.OTP_EMAIL, FactorIds.OTP_PHONE] + else: + if config.contact_config.contact_method == "EMAIL": + all_factors = [FactorIds.OTP_EMAIL, FactorIds.LINK_EMAIL] + elif config.contact_config.contact_method == "PHONE": + all_factors = [FactorIds.OTP_PHONE, FactorIds.LINK_PHONE] + else: + all_factors = [ + FactorIds.OTP_EMAIL, + FactorIds.OTP_PHONE, + FactorIds.LINK_EMAIL, + FactorIds.LINK_PHONE, + ] + + return all_factors From a630649c4a50810016ba5586cd91dfcfc68f4d4c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 4 Sep 2024 20:07:32 +0530 Subject: [PATCH 029/126] fixes linting errors --- .../dashboard/api/userdetails/user_put.py | 4 +-- supertokens_python/recipe/dashboard/utils.py | 2 -- .../recipe/passwordless/api/implementation.py | 4 +-- .../recipe/passwordless/asyncio/__init__.py | 24 +++++++++++++ .../recipe/passwordless/interfaces.py | 12 +++++++ .../passwordless/recipe_implementation.py | 26 ++++++++++++++ .../recipe/passwordless/syncio/__init__.py | 28 +++++++++++++++ tests/auth-react/django3x/mysite/utils.py | 13 +++++-- tests/auth-react/fastapi-server/app.py | 13 +++++-- tests/auth-react/flask-server/app.py | 13 +++++-- tests/passwordless/test_emaildelivery.py | 5 ++- tests/passwordless/test_mutlitenancy.py | 34 ++++++++++-------- tests/test_passwordless.py | 36 ++++++++++--------- 13 files changed, 168 insertions(+), 46 deletions(-) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py index 9770bd313..537e6e9b3 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py @@ -104,7 +104,7 @@ async def update_email_for_recipe_id( return UserPutAPIInvalidEmailErrorResponse(validation_error) update_result = await pless_update_user( - user_id, email, user_context=user_context + RecipeUserId(user_id), email, user_context=user_context ) if isinstance(update_result, PlessUpdateUserUnknownUserIdError): @@ -150,7 +150,7 @@ async def update_phone_for_recipe_id( return UserPutAPIInvalidPhoneErrorResponse(validation_error) update_result = await pless_update_user( - user_id, phone_number=phone, user_context=user_context + RecipeUserId(user_id), phone_number=phone, user_context=user_context ) if isinstance(update_result, PlessUpdateUserUnknownUserIdError): diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index de8f138d1..94edb573f 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -187,12 +187,10 @@ def is_valid_recipe_id(recipe_id: str) -> bool: if TYPE_CHECKING: - from supertokens_python.recipe.passwordless.types import User as PasswordlessUser from supertokens_python.recipe.thirdparty.types import User as ThirdPartyUser GetUserResult = Union[ ThirdPartyUser, - PasswordlessUser, None, ] diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 5ece56109..9b0ed20c4 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -86,7 +86,6 @@ def __init__( async def get_passwordless_user_by_account_info( tenant_id: str, - session: Optional[SessionContainer], user_context: Dict[str, Any], account_info: AccountInfo, ) -> Optional[PasswordlessUserResult]: @@ -161,7 +160,7 @@ async def create_code_post( ) user_with_matching_login_method = await get_passwordless_user_by_account_info( - tenant_id, session, user_context, account_info + tenant_id, user_context, account_info ) factor_ids = [] @@ -350,7 +349,6 @@ async def resend_code_post( user_with_matching_login_method = await get_passwordless_user_by_account_info( tenant_id=tenant_id, - session=session, user_context=user_context, account_info=AccountInfo( email=device_info.email, diff --git a/supertokens_python/recipe/passwordless/asyncio/__init__.py b/supertokens_python/recipe/passwordless/asyncio/__init__.py index 8b0a0bdfb..d1f21e71f 100644 --- a/supertokens_python/recipe/passwordless/asyncio/__init__.py +++ b/supertokens_python/recipe/passwordless/asyncio/__init__.py @@ -138,6 +138,30 @@ async def update_user( ) +async def delete_email_for_user( + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + if user_context is None: + user_context = {} + return await PasswordlessRecipe.get_instance().recipe_implementation.delete_email_for_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + +async def delete_phone_number_for_user( + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + if user_context is None: + user_context = {} + return await PasswordlessRecipe.get_instance().recipe_implementation.delete_phone_number_for_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + async def check_code( tenant_id: str, pre_auth_session_id: str, diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 1470f657c..5247eef4c 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -281,6 +281,18 @@ async def update_user( ]: pass + @abstractmethod + async def delete_email_for_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + pass + + @abstractmethod + async def delete_phone_number_for_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + pass + @abstractmethod async def revoke_all_codes( self, diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index b5d030171..6c65a4435 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -437,6 +437,32 @@ async def revoke_code( ) return RevokeCodeOkResult() + async def delete_email_for_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + data = {"recipeUserId": recipe_user_id.get_as_string(), "email": None} + result = await self.querier.send_put_request( + NormalisedURLPath("/recipe/user"), + data, + user_context=user_context, + ) + if result["status"] == "OK": + return UpdateUserOkResult() + return UpdateUserUnknownUserIdError() + + async def delete_phone_number_for_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + data = {"recipeUserId": recipe_user_id.get_as_string(), "phoneNumber": None} + result = await self.querier.send_put_request( + NormalisedURLPath("/recipe/user"), + data, + user_context=user_context, + ) + if result["status"] == "OK": + return UpdateUserOkResult() + return UpdateUserUnknownUserIdError() + async def update_user( self, recipe_user_id: RecipeUserId, diff --git a/supertokens_python/recipe/passwordless/syncio/__init__.py b/supertokens_python/recipe/passwordless/syncio/__init__.py index 5618a64d3..7f583a389 100644 --- a/supertokens_python/recipe/passwordless/syncio/__init__.py +++ b/supertokens_python/recipe/passwordless/syncio/__init__.py @@ -145,6 +145,34 @@ def update_user( ) +def delete_email_for_user( + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + if user_context is None: + user_context = {} + return sync( + asyncio.delete_email_for_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + ) + + +def delete_phone_number_for_user( + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + if user_context is None: + user_context = {} + return sync( + asyncio.delete_phone_number_for_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + ) + + def revoke_all_codes( tenant_id: str, email: Union[str, None] = None, diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index cd37164c7..111f41a8a 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -509,6 +509,7 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -523,6 +524,7 @@ async def consume_code_post( user_input_code, device_id, link_code, + session, tenant_id, api_options, user_context, @@ -531,6 +533,7 @@ async def consume_code_post( async def create_code_post( email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -541,12 +544,13 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, tenant_id, api_options, user_context + email, phone_number, session, tenant_id, api_options, user_context ) async def resend_code_post( device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -557,7 +561,12 @@ async def resend_code_post( if is_general_error: return GeneralErrorResponse("general error from API resend code") return await original_resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + tenant_id, + api_options, + user_context, ) original_implementation.consume_code_post = consume_code_post diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index 4513b77c9..a459f7207 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -564,6 +564,7 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -578,6 +579,7 @@ async def consume_code_post( user_input_code, device_id, link_code, + session, tenant_id, api_options, user_context, @@ -586,6 +588,7 @@ async def consume_code_post( async def create_code_post( email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -596,12 +599,13 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, tenant_id, api_options, user_context + email, phone_number, session, tenant_id, api_options, user_context ) async def resend_code_post( device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -612,7 +616,12 @@ async def resend_code_post( if is_general_error: return GeneralErrorResponse("general error from API resend code") return await original_resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + tenant_id, + api_options, + user_context, ) original_implementation.consume_code_post = consume_code_post diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 9b7fc7cca..38bf3972f 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -515,6 +515,7 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -529,6 +530,7 @@ async def consume_code_post( user_input_code, device_id, link_code, + session, tenant_id, api_options, user_context, @@ -537,6 +539,7 @@ async def consume_code_post( async def create_code_post( email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -547,12 +550,13 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, tenant_id, api_options, user_context + email, phone_number, session, tenant_id, api_options, user_context ) async def resend_code_post( device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -563,7 +567,12 @@ async def resend_code_post( if is_general_error: return GeneralErrorResponse("general error from API resend code") return await original_resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + tenant_id, + api_options, + user_context, ) original_implementation.consume_code_post = consume_code_post diff --git a/tests/passwordless/test_emaildelivery.py b/tests/passwordless/test_emaildelivery.py index 41d9418b9..5fcca30d4 100644 --- a/tests/passwordless/test_emaildelivery.py +++ b/tests/passwordless/test_emaildelivery.py @@ -20,7 +20,6 @@ import respx from fastapi import FastAPI from fastapi.requests import Request -from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -182,9 +181,9 @@ async def send_email_override( if not is_version_gte(version, "2.11"): return - pless_response = await signinup("public", "test@example.com", None, {}) + pless_response = await signinup("public", "test@example.com", None, None, {}) create_token = await create_email_verification_token( - "public", RecipeUserId(pless_response.user.user_id) + "public", pless_response.recipe_user_id ) assert isinstance(create_token, CreateEmailVerificationTokenOkResult) diff --git a/tests/passwordless/test_mutlitenancy.py b/tests/passwordless/test_mutlitenancy.py index 623ee25c1..a6e1fed35 100644 --- a/tests/passwordless/test_mutlitenancy.py +++ b/tests/passwordless/test_mutlitenancy.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from pytest import mark +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.recipe import session, multitenancy, passwordless from supertokens_python import init from supertokens_python.recipe.multitenancy.asyncio import ( @@ -20,11 +21,10 @@ from supertokens_python.recipe.passwordless.asyncio import ( create_code, consume_code, - get_user_by_id, - get_user_by_email, ConsumeCodeOkResult, ) from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args from tests.utils import ( @@ -109,28 +109,34 @@ async def test_multitenancy_functions(): assert isinstance(user2, ConsumeCodeOkResult) assert isinstance(user3, ConsumeCodeOkResult) - assert user1.user.user_id != user2.user.user_id - assert user2.user.user_id != user3.user.user_id - assert user3.user.user_id != user1.user.user_id + assert user1.user.id != user2.user.id + assert user2.user.id != user3.user.id + assert user3.user.id != user1.user.id assert user1.user.tenant_ids == ["t1"] assert user2.user.tenant_ids == ["t2"] assert user3.user.tenant_ids == ["t3"] # get user by id: - g_user1 = await get_user_by_id(user1.user.user_id) - g_user2 = await get_user_by_id(user2.user.user_id) - g_user3 = await get_user_by_id(user3.user.user_id) + g_user1 = await get_user(user1.user.id) + g_user2 = await get_user(user2.user.id) + g_user3 = await get_user(user3.user.id) assert g_user1 == user1.user assert g_user2 == user2.user assert g_user3 == user3.user # get user by email: - by_email_user1 = await get_user_by_email("t1", "test@example.com") - by_email_user2 = await get_user_by_email("t2", "test@example.com") - by_email_user3 = await get_user_by_email("t3", "test@example.com") + by_email_user1 = await list_users_by_account_info( + "t1", AccountInfo(email="test@example.com") + ) + by_email_user2 = await list_users_by_account_info( + "t2", AccountInfo(email="test@example.com") + ) + by_email_user3 = await list_users_by_account_info( + "t3", AccountInfo(email="test@example.com") + ) - assert by_email_user1 == user1.user - assert by_email_user2 == user2.user - assert by_email_user3 == user3.user + assert by_email_user1 == [user1.user] + assert by_email_user2 == [user2.user] + assert by_email_user3 == [user3.user] diff --git a/tests/test_passwordless.py b/tests/test_passwordless.py index 309dffdcb..bc376e9f9 100644 --- a/tests/test_passwordless.py +++ b/tests/test_passwordless.py @@ -15,6 +15,12 @@ from typing import Any, Dict from fastapi import FastAPI +from supertokens_python.asyncio import get_user, list_users_by_account_info +from supertokens_python.recipe.passwordless.asyncio import ( + delete_email_for_user, + delete_phone_number_for_user, +) +from supertokens_python.types import AccountInfo, RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, raises, skip from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -22,16 +28,10 @@ from supertokens_python.querier import Querier from supertokens_python.recipe import passwordless, session from supertokens_python.recipe.passwordless.asyncio import ( - delete_email_for_user, - delete_phone_number_for_user, - get_user_by_email, - get_user_by_id, - get_user_by_phone_number, update_user, create_magic_link, ) from supertokens_python.recipe.passwordless.interfaces import ( - DeleteUserInfoOkResult, UpdateUserOkResult, ) from supertokens_python.utils import is_version_gte @@ -227,14 +227,16 @@ async def send_sms( await update_user(user_id, "foo@example.com", "+919494949494") - response = await delete_phone_number_for_user(user_id) - assert isinstance(response, DeleteUserInfoOkResult) + response = await delete_phone_number_for_user(RecipeUserId(user_id)) + assert isinstance(response, UpdateUserOkResult) - user = await get_user_by_phone_number("public", "+919494949494") + user = await list_users_by_account_info( + "public", AccountInfo(phone_number="+919494949494") + ) assert user is None - user = await get_user_by_id(user_id) - assert user is not None and user.phone_number is None + user = await get_user(user_id) + assert user is not None and user.phone_numbers == [] @mark.asyncio @@ -310,14 +312,16 @@ async def send_sms( await update_user(user_id, "hello@example.com", "+919494949494") - response = await delete_email_for_user(user_id) - assert isinstance(response, DeleteUserInfoOkResult) + response = await delete_email_for_user(RecipeUserId(user_id)) + assert isinstance(response, UpdateUserOkResult) - user = await get_user_by_email("public", "hello@example.com") + user = await list_users_by_account_info( + "public", AccountInfo(email="hello@example.com") + ) assert user is None - user = await get_user_by_id(user_id) - assert user is not None and user.email is None + user = await get_user(user_id) + assert user is not None and user.emails == [] @mark.asyncio From 7251e056d59f566f12431de5d5ab6a42a578134f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 6 Sep 2024 15:39:55 +0530 Subject: [PATCH 030/126] third party recipe changes --- supertokens_python/auth_utils.py | 2 +- supertokens_python/recipe/dashboard/utils.py | 9 - .../recipe/passwordless/api/consume_code.py | 24 +- .../recipe/thirdparty/api/implementation.py | 191 +++++++++++--- .../recipe/thirdparty/api/signinup.py | 30 ++- .../recipe/thirdparty/asyncio/__init__.py | 70 ++--- .../recipe/thirdparty/interfaces.py | 91 +++---- .../recipe/thirdparty/recipe.py | 11 +- .../thirdparty/recipe_implementation.py | 239 +++++++++--------- .../recipe/thirdparty/syncio/__init__.py | 63 ++--- supertokens_python/recipe/thirdparty/types.py | 41 +-- supertokens_python/types.py | 8 +- tests/auth-react/django3x/mysite/utils.py | 2 + tests/auth-react/fastapi-server/app.py | 2 + tests/auth-react/flask-server/app.py | 2 + tests/dashboard/test_dashboard.py | 11 +- tests/test_querier.py | 9 +- tests/thirdparty/test_emaildelivery.py | 20 +- tests/thirdparty/test_multitenancy.py | 106 ++++++-- tests/utils.py | 14 +- 20 files changed, 540 insertions(+), 405 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index ae9b02f77..92f8de1f8 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -23,11 +23,11 @@ from supertokens_python.recipe.session.interfaces import SessionContainer from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.types import ( AccountInfo, AccountLinkingUser, LoginMethod, - ThirdPartyInfo, ) from supertokens_python.recipe.accountlinking.interfaces import ( RecipeUserId, diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 94edb573f..310b3c098 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -186,15 +186,6 @@ def is_valid_recipe_id(recipe_id: str) -> bool: return recipe_id in ("emailpassword", "thirdparty", "passwordless") -if TYPE_CHECKING: - from supertokens_python.recipe.thirdparty.types import User as ThirdPartyUser - - GetUserResult = Union[ - ThirdPartyUser, - None, - ] - - class GetUserForRecipeIdHelperResult: def __init__( self, user: Optional[AccountLinkingUser] = None, recipe: Optional[str] = None diff --git a/supertokens_python/recipe/passwordless/api/consume_code.py b/supertokens_python/recipe/passwordless/api/consume_code.py index 4da19d9bf..bf2c24681 100644 --- a/supertokens_python/recipe/passwordless/api/consume_code.py +++ b/supertokens_python/recipe/passwordless/api/consume_code.py @@ -13,9 +13,16 @@ # under the License. from typing import Any, Dict from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.passwordless.interfaces import ( + APIInterface, + APIOptions, + ConsumeCodePostOkResult, +) from supertokens_python.recipe.session.asyncio import get_session -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + send_200_response, +) async def consume_code( @@ -76,4 +83,17 @@ async def consume_code( api_options, user_context, ) + + if isinstance(result, ConsumeCodePostOkResult): + return send_200_response( + get_backwards_compatible_user_info( + req=api_options.request, + user_info=result.user, + session_container=result.session, + created_new_recipe_user=result.created_new_recipe_user, + user_context=user_context, + ), + api_options.response, + ) + return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index 6a49d423d..a1dfd9bba 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -17,26 +17,36 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Union from urllib.parse import parse_qs, urlencode, urlparse +from supertokens_python.auth_utils import ( + OkResponse, + PostAuthChecksOkResponse, + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + get_authenticating_user_and_add_to_current_tenant_if_required, + post_auth_checks, + pre_auth_checks, +) +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId from supertokens_python.recipe.emailverification import EmailVerificationRecipe -from supertokens_python.recipe.emailverification.interfaces import ( - CreateEmailVerificationTokenOkResult, -) -from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.recipe.emailverification.asyncio import is_email_verified +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface, AuthorisationUrlGetOkResult, + SignInUpNotAllowed, + SignInUpOkResult, SignInUpPostNoEmailGivenByProviderResponse, SignInUpPostOkResult, ) from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo -from supertokens_python.recipe.thirdparty.types import UserInfoEmail +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo, UserInfoEmail if TYPE_CHECKING: from supertokens_python.recipe.thirdparty.interfaces import APIOptions from supertokens_python.recipe.thirdparty.provider import Provider -from supertokens_python.types import GeneralErrorResponse, RecipeUserId +from supertokens_python.types import GeneralErrorResponse class APIImplementation(APIInterface): @@ -62,14 +72,27 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Optional[RedirectUriInfo], oauth_tokens: Optional[Dict[str, Any]], + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ SignInUpPostOkResult, SignInUpPostNoEmailGivenByProviderResponse, + SignInUpNotAllowed, GeneralErrorResponse, ]: + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_006)", + "SIGN_IN_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_004)", + "LINKING_TO_SESSION_USER_FAILED": { + "EMAIL_VERIFICATION_REQUIRED": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_020)", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_021)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_022)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_023)", + }, + } + oauth_tokens_to_use: Dict[str, Any] = {} if redirect_uri_info is not None: @@ -77,8 +100,10 @@ async def sign_in_up_post( redirect_uri_info=redirect_uri_info, user_context=user_context, ) + elif oauth_tokens is not None: + oauth_tokens_to_use = oauth_tokens else: - oauth_tokens_to_use = oauth_tokens # type: ignore + raise Exception("should never come here") user_info = await provider.get_user_info( oauth_tokens=oauth_tokens_to_use, @@ -88,61 +113,145 @@ async def sign_in_up_post( if user_info.email is None and provider.config.require_email is False: # We don't expect to get an email from this provider. # So we generate a fake one - if provider.config.generate_fake_email is not None: - user_info.email = UserInfoEmail( - email=await provider.config.generate_fake_email( - tenant_id, user_info.third_party_user_id, user_context - ), - is_verified=True, + assert provider.config.generate_fake_email is not None + user_info.email = UserInfoEmail( + email=await provider.config.generate_fake_email( + tenant_id, user_info.third_party_user_id, user_context + ), + is_verified=True, + ) + + email_info = user_info.email + if email_info is None: + return SignInUpPostNoEmailGivenByProviderResponse() + + recipe_id = "thirdparty" + + async def check_credentials_on_tenant(_: str): + # We essentially did this above when calling exchange_auth_code_for_oauth_tokens + return True + + authenticating_user = ( + await get_authenticating_user_and_add_to_current_tenant_if_required( + third_party=ThirdPartyInfo( + third_party_user_id=user_info.third_party_user_id, + third_party_id=provider.id, + ), + email=None, + phone_number=None, + recipe_id=recipe_id, + user_context=user_context, + session=session, + tenant_id=tenant_id, + check_credentials_on_tenant=check_credentials_on_tenant, + ) + ) + + is_sign_up = authenticating_user is None + if authenticating_user is not None: + # This is a sign in. So before we proceed, we need to check if an email change + # is allowed since the email could have changed from the social provider's side. + # We do this check here and not in the recipe function cause we want to keep the + # recipe function checks to a minimum so that the dev has complete control of + # what they can do. + + # The is_email_change_allowed and pre_auth_checks functions take an is_verified boolean. + # Now, even though we already have that from the input, that's just what the provider says. + # If the provider says that the email is NOT verified, it could have been that the email + # is verified on the user's account via supertokens on a previous sign in / up. + # So we just check that as well before calling is_email_change_allowed + + assert authenticating_user.login_method is not None + recipe_user_id = authenticating_user.login_method.recipe_user_id + + if ( + not email_info.is_verified + and EmailVerificationRecipe.get_instance_optional() is not None + ): + email_info.is_verified = await is_email_verified( + recipe_user_id, + email_info.id, + user_context, ) - email = user_info.email.id if user_info.email is not None else None - email_verified = ( - user_info.email.is_verified if user_info.email is not None else None + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id=recipe_id, + email=email_info.id, + third_party=ThirdPartyInfo( + third_party_user_id=user_info.third_party_user_id, + third_party_id=provider.id, + ), + ), + authenticating_user=( + authenticating_user.user if authenticating_user else None + ), + factor_ids=["thirdparty"], + is_sign_up=is_sign_up, + is_verified=email_info.is_verified, + sign_in_verifies_login_method=email_info.is_verified, + skip_session_user_update_in_core=False, + tenant_id=tenant_id, + user_context=user_context, + session=session, ) - if email is None: - return SignInUpPostNoEmailGivenByProviderResponse() + + if not isinstance(pre_auth_checks_result, OkResponse): + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpNotAllowed(reason) + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpNotAllowed(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInUpNotAllowed(reason=reason) signinup_response = await api_options.recipe_implementation.sign_in_up( third_party_id=provider.id, third_party_user_id=user_info.third_party_user_id, - email=email, + email=email_info.id, + is_verified=email_info.is_verified, oauth_tokens=oauth_tokens_to_use, raw_user_info_from_provider=user_info.raw_user_info_from_provider, + session=session, tenant_id=tenant_id, user_context=user_context, ) - if email_verified: - ev_instance = EmailVerificationRecipe.get_instance_optional() - if ev_instance is not None: - token_response = await ev_instance.recipe_implementation.create_email_verification_token( - tenant_id=tenant_id, - recipe_user_id=RecipeUserId(signinup_response.user.user_id), - email=signinup_response.user.email, - user_context=user_context, - ) + if isinstance(signinup_response, SignInUpNotAllowed): + return signinup_response - if isinstance(token_response, CreateEmailVerificationTokenOkResult): - await ev_instance.recipe_implementation.verify_email_using_token( - token=token_response.token, - tenant_id=tenant_id, - attempt_account_linking=True, - user_context=user_context, - ) + if not isinstance(signinup_response, SignInUpOkResult): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[signinup_response.reason] + return SignInUpNotAllowed(reason=reason) - user = signinup_response.user - session = await create_new_session( + post_auth_checks_result = await post_auth_checks( + factor_id="thirdparty", + is_sign_up=is_sign_up, + authenticated_user=signinup_response.user, + recipe_user_id=signinup_response.recipe_user_id, request=api_options.request, tenant_id=tenant_id, - recipe_user_id=RecipeUserId(user.user_id), user_context=user_context, + session=session, ) + if not isinstance(post_auth_checks_result, PostAuthChecksOkResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpNotAllowed(reason) + return SignInUpPostOkResult( - created_new_user=signinup_response.created_new_user, - user=user, - session=session, + created_new_recipe_user=signinup_response.created_new_recipe_user, + user=post_auth_checks_result.user, + session=post_auth_checks_result.session, oauth_tokens=oauth_tokens_to_use, raw_user_info_from_provider=user_info.raw_user_info_from_provider, ) diff --git a/supertokens_python/recipe/thirdparty/api/signinup.py b/supertokens_python/recipe/thirdparty/api/signinup.py index 60937ea49..af94ccdc6 100644 --- a/supertokens_python/recipe/thirdparty/api/signinup.py +++ b/supertokens_python/recipe/thirdparty/api/signinup.py @@ -14,13 +14,18 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.recipe.thirdparty.interfaces import SignInUpPostOkResult from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo if TYPE_CHECKING: from supertokens_python.recipe.thirdparty.interfaces import APIOptions, APIInterface from supertokens_python.exceptions import raise_bad_input_exception, BadInputError -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + send_200_response, +) async def handle_sign_in_up_api( @@ -80,6 +85,15 @@ async def handle_sign_in_up_api( pkce_code_verifier=redirect_uri_info.get("pkceCodeVerifier"), ) + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.sign_in_up_post( provider=provider, redirect_uri_info=redirect_uri_info, @@ -87,5 +101,19 @@ async def handle_sign_in_up_api( tenant_id=tenant_id, api_options=api_options, user_context=user_context, + session=session, ) + + if isinstance(result, SignInUpPostOkResult): + return send_200_response( + get_backwards_compatible_user_info( + req=api_options.request, + user_info=result.user, + session_container=result.session, + created_new_recipe_user=result.created_new_recipe_user, + user_context=user_context, + ), + api_options.response, + ) + return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/thirdparty/asyncio/__init__.py b/supertokens_python/recipe/thirdparty/asyncio/__init__.py index 0e70c5166..1568362ae 100644 --- a/supertokens_python/recipe/thirdparty/asyncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/asyncio/__init__.py @@ -12,66 +12,42 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.thirdparty.interfaces import ( + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe -from ..types import User - - -async def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return await ThirdPartyRecipe.get_instance().recipe_implementation.get_user_by_id( - user_id, user_context - ) - - -async def get_users_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> List[User]: - if user_context is None: - user_context = {} - return ( - await ThirdPartyRecipe.get_instance().recipe_implementation.get_users_by_email( - email, tenant_id, user_context - ) - ) - - -async def get_user_by_third_party_info( - tenant_id: str, - third_party_id: str, - third_party_user_id: str, - user_context: Union[None, Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - return await ThirdPartyRecipe.get_instance().recipe_implementation.get_user_by_thirdparty_info( - third_party_id, - third_party_user_id, - tenant_id, - user_context, - ) - async def manually_create_or_update_user( tenant_id: str, third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer], user_context: Union[None, Dict[str, Any]] = None, -): +) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, +]: if user_context is None: user_context = {} return await ThirdPartyRecipe.get_instance().recipe_implementation.manually_create_or_update_user( - third_party_id, - third_party_user_id, - email, - tenant_id, - user_context, + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + email=email, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + user_context=user_context, ) diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index b4c5c6bee..1d106a79b 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -16,7 +16,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from ...types import APIResponse, GeneralErrorResponse +from supertokens_python.auth_utils import LinkingToSessionUserFailedError + +from ...types import APIResponse, AccountLinkingUser, GeneralErrorResponse, RecipeUserId from .provider import Provider, ProviderInput, RedirectUriInfo if TYPE_CHECKING: @@ -24,32 +26,36 @@ from supertokens_python.recipe.session import SessionContainer from supertokens_python.supertokens import AppInfo - from .types import User, RawUserInfoFromProvider + from .types import RawUserInfoFromProvider from .utils import ThirdPartyConfig class SignInUpOkResult: def __init__( self, - user: User, - created_new_user: bool, + user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + created_new_recipe_user: bool, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, ): self.user = user - self.created_new_user = created_new_user + self.created_new_recipe_user = created_new_recipe_user self.oauth_tokens = oauth_tokens self.raw_user_info_from_provider = raw_user_info_from_provider + self.recipe_user_id = recipe_user_id class ManuallyCreateOrUpdateUserOkResult: def __init__( self, - user: User, - created_new_user: bool, + user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + created_new_recipe_user: bool, ): self.user = user - self.created_new_user = created_new_user + self.recipe_user_id = recipe_user_id + self.created_new_recipe_user = created_new_recipe_user class GetProviderOkResult: @@ -57,30 +63,24 @@ def __init__(self, provider: Provider): self.provider = provider -class RecipeInterface(ABC): - def __init__(self): - pass +class SignInUpNotAllowed(APIResponse): + status: str = "SIGN_IN_UP_NOT_ALLOWED" + reason: str - @abstractmethod - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - pass + def __init__(self, reason: str): + self.reason = reason - @abstractmethod - async def get_users_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> List[User]: - pass + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "reason": self.reason} - @abstractmethod - async def get_user_by_thirdparty_info( - self, - third_party_id: str, - third_party_user_id: str, - tenant_id: str, - user_context: Dict[str, Any], - ) -> Union[User, None]: + +class EmailChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason + + +class RecipeInterface(ABC): + def __init__(self): pass @abstractmethod @@ -89,9 +89,16 @@ async def manually_create_or_update_user( third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], - ) -> ManuallyCreateOrUpdateUserOkResult: + ) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, + ]: pass @abstractmethod @@ -100,11 +107,13 @@ async def sign_in_up( third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], - ) -> SignInUpOkResult: + ) -> Union[SignInUpOkResult, SignInUpNotAllowed, LinkingToSessionUserFailedError]: pass @abstractmethod @@ -143,14 +152,14 @@ class SignInUpPostOkResult(APIResponse): def __init__( self, - user: User, - created_new_user: bool, + user: AccountLinkingUser, + created_new_recipe_user: bool, session: SessionContainer, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, ): self.user = user - self.created_new_user = created_new_user + self.created_new_recipe_user = created_new_recipe_user self.session = session self.oauth_tokens = oauth_tokens self.raw_user_info_from_provider = raw_user_info_from_provider @@ -158,16 +167,8 @@ def __init__( def to_json(self) -> Dict[str, Any]: return { "status": self.status, - "user": { - "id": self.user.user_id, - "email": self.user.email, - "timeJoined": self.user.time_joined, - "thirdParty": { - "id": self.user.third_party_info.id, - "userId": self.user.third_party_info.user_id, - }, - }, - "createdNewUser": self.created_new_user, + "user": self.user.to_json(), + "createdNewRecipeUser": self.created_new_recipe_user, } @@ -217,12 +218,14 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Optional[RedirectUriInfo], oauth_tokens: Optional[Dict[str, Any]], + session: Optional[SessionContainer], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ SignInUpPostOkResult, SignInUpPostNoEmailGivenByProviderResponse, + SignInUpNotAllowed, GeneralErrorResponse, ]: pass diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index f9f50af0d..9a7ca219f 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -23,7 +23,6 @@ from .api.implementation import APIImplementation from .interfaces import APIInterface, APIOptions, RecipeInterface from .recipe_implementation import RecipeImplementation -from ..emailverification.interfaces import GetEmailForUserIdOkResult, UnknownUserIdError from ...post_init_callbacks import PostSTInitCallbacks if TYPE_CHECKING: @@ -83,6 +82,7 @@ def callback(): mt_recipe = MultitenancyRecipe.get_instance_optional() if mt_recipe: mt_recipe.static_third_party_providers = self.providers + mt_recipe.all_available_first_factors.append("thirdparty") PostSTInitCallbacks.add_post_init_callback(callback) @@ -200,12 +200,3 @@ def reset(): ThirdPartyRecipe.__instance = None # instance functions below............... - - async def get_email_for_user_id(self, user_id: str, user_context: Dict[str, Any]): - user_info = await self.recipe_implementation.get_user_by_id( - user_id, user_context - ) - if user_info is not None: - return GetEmailForUserIdOkResult(user_info.email) - - return UnknownUserIdError() diff --git a/supertokens_python/recipe/thirdparty/recipe_implementation.py b/supertokens_python/recipe/thirdparty/recipe_implementation.py index 0647212ea..de2db33e0 100644 --- a/supertokens_python/recipe/thirdparty/recipe_implementation.py +++ b/supertokens_python/recipe/thirdparty/recipe_implementation.py @@ -14,25 +14,35 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.asyncio import get_user, list_users_by_account_info +from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedError, + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, +) from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.provider import ProviderInput from supertokens_python.recipe.thirdparty.providers.config_utils import ( find_and_create_provider_instance, merge_providers_from_core_and_static, ) +from supertokens_python.types import AccountInfo, AccountLinkingUser, RecipeUserId if TYPE_CHECKING: from supertokens_python.querier import Querier from .interfaces import ( + EmailChangeNotAllowedError, ManuallyCreateOrUpdateUserOkResult, RecipeInterface, + SignInUpNotAllowed, SignInUpOkResult, ) -from .types import RawUserInfoFromProvider, ThirdPartyInfo, User +from .types import RawUserInfoFromProvider, ThirdPartyInfo class RecipeImplementation(RecipeInterface): @@ -41,149 +51,138 @@ def __init__(self, querier: Querier, providers: List[ProviderInput]): self.querier = querier self.providers = providers - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - params = {"userId": user_id} - response = await self.querier.send_get_request( - NormalisedURLPath("/recipe/user"), - params, - user_context=user_context, - ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), - ) - return None - - async def get_users_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> List[User]: - response = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/users/by-email"), - {"email": email}, - user_context=user_context, - ) - users: List[User] = [] - users_list: List[Dict[str, Any]] = ( - response["users"] if "users" in response else [] - ) - for user in users_list: - users.append( - User( - user["id"], - user["email"], - user["timeJoined"], - user["tenantIds"], - ThirdPartyInfo( - user["thirdParty"]["userId"], user["thirdParty"]["id"] - ), - ) - ) - return users - - async def get_user_by_thirdparty_info( - self, - third_party_id: str, - third_party_user_id: str, - tenant_id: str, - user_context: Dict[str, Any], - ) -> Union[User, None]: - params = { - "thirdPartyId": third_party_id, - "thirdPartyUserId": third_party_user_id, - } - response = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), - params, - user_context=user_context, - ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), - ) - return None - async def sign_in_up( self, third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], - ) -> SignInUpOkResult: - data = { - "thirdPartyId": third_party_id, - "thirdPartyUserId": third_party_user_id, - "email": {"id": email}, - } - response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup"), - data, + ) -> Union[SignInUpOkResult, SignInUpNotAllowed, LinkingToSessionUserFailedError]: + response = await self.manually_create_or_update_user( + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + email=email, + tenant_id=tenant_id, + is_verified=is_verified, + session=session, user_context=user_context, ) - return SignInUpOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), - ), - response["createdNewUser"], - oauth_tokens, - raw_user_info_from_provider, - ) + + if isinstance(response, EmailChangeNotAllowedError): + return SignInUpNotAllowed( + "Cannot sign in / up because new email cannot be applied to existing account. Please contact support. (ERR_CODE_005)" + if response.reason + == "Email already associated with another primary user." + else "Cannot sign in / up because new email cannot be applied to existing account. Please contact support. (ERR_CODE_024)" + ) + + if isinstance(response, ManuallyCreateOrUpdateUserOkResult): + return SignInUpOkResult( + user=response.user, + recipe_user_id=response.recipe_user_id, + created_new_recipe_user=response.created_new_recipe_user, + oauth_tokens=oauth_tokens, + raw_user_info_from_provider=raw_user_info_from_provider, + ) + + return response async def manually_create_or_update_user( self, third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], - ) -> ManuallyCreateOrUpdateUserOkResult: - data = { - "thirdPartyId": third_party_id, - "thirdPartyUserId": third_party_user_id, - "email": {"id": email}, - } + ) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, + ]: + account_linking = AccountLinkingRecipe.get_instance() + users = await list_users_by_account_info( + tenant_id, + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + ), + ), + False, + user_context, + ) + + user = users[0] if users else None + if user is not None: + is_email_change_allowed = await account_linking.is_email_change_allowed( + user=user, + is_verified=is_verified, + new_email=email, + session=session, + user_context=user_context, + ) + if not is_email_change_allowed.allowed: + reason = ( + "Email already associated with another primary user." + if is_email_change_allowed.reason == "PRIMARY_USER_CONFLICT" + else "New email cannot be applied to existing account because of account takeover risks." + ) + return EmailChangeNotAllowedError(reason) + response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/signinup"), - data, + { + "thirdPartyId": third_party_id, + "thirdPartyUserId": third_party_user_id, + "email": {"id": email, "isVerified": is_verified}, + }, + user_context=user_context, + ) + + if response["status"] == "EMAIL_CHANGE_NOT_ALLOWED_ERROR": + return EmailChangeNotAllowedError(response["reason"]) + + # status is OK + + user = AccountLinkingUser.from_json( + response["user"], + ) + recipe_user_id = RecipeUserId(response["recipeUserId"]) + + await account_linking.verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + # Fetch updated user + user = await get_user(recipe_user_id.get_as_string(), user_context) + + assert user is not None + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=user, + recipe_user_id=recipe_user_id, + session=session, user_context=user_context, ) + + if link_result.status != "OK": + return link_result + return ManuallyCreateOrUpdateUserOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), - ), - response["createdNewUser"], + user=link_result.user, + recipe_user_id=recipe_user_id, + created_new_recipe_user=response["createdNewUser"], ) async def get_provider( diff --git a/supertokens_python/recipe/thirdparty/syncio/__init__.py b/supertokens_python/recipe/thirdparty/syncio/__init__.py index 218c6862e..429b4b027 100644 --- a/supertokens_python/recipe/thirdparty/syncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/syncio/__init__.py @@ -11,46 +11,16 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union from supertokens_python.async_to_sync_wrapper import sync - -from ..types import User - - -def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - from supertokens_python.recipe.thirdparty.asyncio import get_user_by_id - - return sync(get_user_by_id(user_id, user_context)) - - -def get_users_by_email( - tenant_id: str, - email: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> List[User]: - from supertokens_python.recipe.thirdparty.asyncio import get_users_by_email - - return sync(get_users_by_email(tenant_id, email, user_context)) - - -def get_user_by_third_party_info( - tenant_id: str, - third_party_id: str, - third_party_user_id: str, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_third_party_info, - ) - - return sync( - get_user_by_third_party_info( - tenant_id, third_party_id, third_party_user_id, user_context - ) - ) +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.thirdparty.interfaces import ( + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) def manually_create_or_update_user( @@ -58,15 +28,28 @@ def manually_create_or_update_user( third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer], user_context: Union[None, Dict[str, Any]] = None, -): +) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, +]: from supertokens_python.recipe.thirdparty.asyncio import ( manually_create_or_update_user, ) return sync( manually_create_or_update_user( - tenant_id, third_party_id, third_party_user_id, email, user_context + email=email, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + user_context=user_context, ) ) diff --git a/supertokens_python/recipe/thirdparty/types.py b/supertokens_python/recipe/thirdparty/types.py index 5d18c6b3a..8b2be0d66 100644 --- a/supertokens_python/recipe/thirdparty/types.py +++ b/supertokens_python/recipe/thirdparty/types.py @@ -11,10 +11,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Union, Optional +from __future__ import annotations + +from typing import Any, Callable, Dict, Union, Optional, TYPE_CHECKING from supertokens_python.framework.request import BaseRequest +if TYPE_CHECKING: + from supertokens_python.types import AccountLinkingUser + class ThirdPartyInfo: def __init__(self, third_party_user_id: str, third_party_id: str): @@ -39,32 +44,6 @@ def __init__( self.from_user_info_api = from_user_info_api -class User: - def __init__( - self, - user_id: str, - email: str, - time_joined: int, - tenant_ids: List[str], - third_party_info: ThirdPartyInfo, - ): - self.user_id: str = user_id - self.email: str = email - self.time_joined: int = time_joined - self.tenant_ids = tenant_ids - self.third_party_info: ThirdPartyInfo = third_party_info - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.user_id == other.user_id - and self.email == other.email - and self.time_joined == other.time_joined - and self.tenant_ids == other.tenant_ids - and self.third_party_info == other.third_party_info - ) - - class UserInfoEmail: def __init__(self, email: str, is_verified: bool): self.id: str = email @@ -100,16 +79,10 @@ def __init__( class SignInUpResponse: - def __init__(self, user: User, is_new_user: bool): + def __init__(self, user: AccountLinkingUser, is_new_user: bool): self.user = user self.is_new_user = is_new_user -class UsersResponse: - def __init__(self, users: List[User], next_pagination_token: Union[str, None]): - self.users = users - self.next_pagination_token = next_pagination_token - - class ThirdPartyIngredients: pass diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 4a753e885..14ef4e4a4 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -18,6 +18,8 @@ import phonenumbers # type: ignore from typing_extensions import Literal +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo + _T = TypeVar("_T") @@ -34,12 +36,6 @@ def __eq__(self, other: Any) -> bool: return False -class ThirdPartyInfo: - def __init__(self, third_party_id: str, third_party_user_id: str): - self.id = third_party_id - self.user_id = third_party_user_id - - class AccountInfo: def __init__( self, diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 111f41a8a..46ba367a2 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -442,6 +442,7 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], + session: Optional[SessionContainer], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -455,6 +456,7 @@ async def sign_in_up_post( provider, redirect_uri_info, oauth_tokens, + session, tenant_id, api_options, user_context, diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index a459f7207..bc363b69d 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -497,6 +497,7 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], + session: Optional[SessionContainer], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -510,6 +511,7 @@ async def sign_in_up_post( provider, redirect_uri_info, oauth_tokens, + session, tenant_id, api_options, user_context, diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 38bf3972f..fe711cf53 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -448,6 +448,7 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], + session: Optional[SessionContainer], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -461,6 +462,7 @@ async def sign_in_up_post( provider, redirect_uri_info, oauth_tokens, + session, tenant_id, api_options, user_context, diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index dc8320332..f4392d449 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -2,6 +2,9 @@ from fastapi import FastAPI from pytest import fixture, mark +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, +) from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import init from supertokens_python.constants import DASHBOARD_VERSION @@ -248,9 +251,11 @@ async def should_allow_access( start_st() pluser = await manually_create_or_update_user( - "public", "google", "googleid", "test@example.com" + "public", "google", "googleid", "test@example.com", True, None ) + assert isinstance(pluser, ManuallyCreateOrUpdateUserOkResult) + res = app.get( url="/auth/dashboard/api/user", params={ @@ -264,10 +269,10 @@ async def should_allow_access( res = app.get( url="/auth/dashboard/api/user", params={ - "userId": pluser.user.user_id, + "userId": pluser.user.id, "recipeId": "thirdparty", }, ) res_json = res.json() assert res_json["status"] == "OK" - assert res_json["user"]["id"] == pluser.user.user_id + assert res_json["user"]["id"] == pluser.user.id diff --git a/tests/test_querier.py b/tests/test_querier.py index b6722a019..7ef9323c8 100644 --- a/tests/test_querier.py +++ b/tests/test_querier.py @@ -21,9 +21,6 @@ ) from supertokens_python import InputAppInfo from supertokens_python.recipe.emailpassword.asyncio import get_user, sign_up -from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_id as tp_get_user, -) import asyncio import respx import httpx @@ -247,14 +244,14 @@ def intercept( assert user is None assert not called_core - user = await tp_get_user("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await tp_get_user("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core @@ -420,7 +417,7 @@ def intercept( called_core = False - user = await tp_get_user("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core diff --git a/tests/thirdparty/test_emaildelivery.py b/tests/thirdparty/test_emaildelivery.py index d42fd65c5..dd33e2bdf 100644 --- a/tests/thirdparty/test_emaildelivery.py +++ b/tests/thirdparty/test_emaildelivery.py @@ -140,14 +140,14 @@ async def test_email_verify_default_backward_compatibility( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", True, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id response = await create_new_session( s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) @@ -214,14 +214,14 @@ async def test_email_verify_default_backward_compatibility_supress_error( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", True, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id response = await create_new_session( s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) @@ -304,14 +304,14 @@ async def send_email( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", True, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id response = await create_new_session( s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) @@ -382,14 +382,14 @@ async def send_email( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", True, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id assert isinstance(user_id, str) response = await create_new_session( s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None @@ -522,14 +522,14 @@ async def send_email_override( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", True, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id assert isinstance(user_id, str) response = await create_new_session( s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None diff --git a/tests/thirdparty/test_multitenancy.py b/tests/thirdparty/test_multitenancy.py index d507fd409..113ba9b38 100644 --- a/tests/thirdparty/test_multitenancy.py +++ b/tests/thirdparty/test_multitenancy.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from pytest import mark +from supertokens_python.asyncio import get_user from supertokens_python.recipe import session, multitenancy, thirdparty from supertokens_python import init from supertokens_python.recipe.multitenancy.asyncio import ( @@ -20,12 +21,15 @@ ) from supertokens_python.recipe.thirdparty.asyncio import ( manually_create_or_update_user, - get_user_by_id, - get_users_by_email, - get_user_by_third_party_info, get_provider, ) from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, +) +from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args from tests.utils import ( @@ -55,23 +59,29 @@ async def test_thirtyparty_multitenancy_functions(): # sign up: user1a = await manually_create_or_update_user( - "t1", "google", "googleid1", "test@example.com" + "t1", "google", "googleid1", "test@example.com", True, None ) + assert isinstance(user1a, ManuallyCreateOrUpdateUserOkResult) user1b = await manually_create_or_update_user( - "t1", "facebook", "fbid1", "test@example.com" + "t1", "facebook", "fbid1", "test@example.com", True, None ) + assert isinstance(user1b, ManuallyCreateOrUpdateUserOkResult) user2a = await manually_create_or_update_user( - "t2", "google", "googleid1", "test@example.com" + "t2", "google", "googleid1", "test@example.com", True, None ) + assert isinstance(user2a, ManuallyCreateOrUpdateUserOkResult) user2b = await manually_create_or_update_user( - "t2", "facebook", "fbid1", "test@example.com" + "t2", "facebook", "fbid1", "test@example.com", True, None ) + assert isinstance(user2b, ManuallyCreateOrUpdateUserOkResult) user3a = await manually_create_or_update_user( - "t3", "google", "googleid1", "test@example.com" + "t3", "google", "googleid1", "test@example.com", True, None ) + assert isinstance(user3a, ManuallyCreateOrUpdateUserOkResult) user3b = await manually_create_or_update_user( - "t3", "facebook", "fbid1", "test@example.com" + "t3", "facebook", "fbid1", "test@example.com", True, None ) + assert isinstance(user3b, ManuallyCreateOrUpdateUserOkResult) assert user1a.user.tenant_ids == ["t1"] assert user1b.user.tenant_ids == ["t1"] @@ -81,12 +91,12 @@ async def test_thirtyparty_multitenancy_functions(): assert user3b.user.tenant_ids == ["t3"] # get user by id: - g_user1a = await get_user_by_id(user1a.user.user_id) - g_user1b = await get_user_by_id(user1b.user.user_id) - g_user2a = await get_user_by_id(user2a.user.user_id) - g_user2b = await get_user_by_id(user2b.user.user_id) - g_user3a = await get_user_by_id(user3a.user.user_id) - g_user3b = await get_user_by_id(user3b.user.user_id) + g_user1a = await get_user(user1a.user.id) + g_user1b = await get_user(user1b.user.id) + g_user2a = await get_user(user2a.user.id) + g_user2b = await get_user(user2b.user.id) + g_user3a = await get_user(user3a.user.id) + g_user3b = await get_user(user3b.user.id) assert g_user1a == user1a.user assert g_user1b == user1b.user @@ -96,21 +106,69 @@ async def test_thirtyparty_multitenancy_functions(): assert g_user3b == user3b.user # get user by email: - by_email_user1 = await get_users_by_email("t1", "test@example.com") - by_email_user2 = await get_users_by_email("t2", "test@example.com") - by_email_user3 = await get_users_by_email("t3", "test@example.com") + by_email_user1 = await list_users_by_account_info( + "t1", AccountInfo(email="test@example.com") + ) + by_email_user2 = await list_users_by_account_info( + "t2", AccountInfo(email="test@example.com") + ) + by_email_user3 = await list_users_by_account_info( + "t3", AccountInfo(email="test@example.com") + ) assert by_email_user1 == [user1a.user, user1b.user] assert by_email_user2 == [user2a.user, user2b.user] assert by_email_user3 == [user3a.user, user3b.user] # get user by thirdparty id: - g_user_by_tpid1a = await get_user_by_third_party_info("t1", "google", "googleid1") - g_user_by_tpid1b = await get_user_by_third_party_info("t1", "facebook", "fbid1") - g_user_by_tpid2a = await get_user_by_third_party_info("t2", "google", "googleid1") - g_user_by_tpid2b = await get_user_by_third_party_info("t2", "facebook", "fbid1") - g_user_by_tpid3a = await get_user_by_third_party_info("t3", "google", "googleid1") - g_user_by_tpid3b = await get_user_by_third_party_info("t3", "facebook", "fbid1") + g_user_by_tpid1a = await list_users_by_account_info( + "t1", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="google", third_party_user_id="googleid1" + ) + ), + ) + g_user_by_tpid1b = await list_users_by_account_info( + "t1", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="facebook", third_party_user_id="fbid1" + ) + ), + ) + g_user_by_tpid2a = await list_users_by_account_info( + "t2", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="google", third_party_user_id="googleid1" + ) + ), + ) + g_user_by_tpid2b = await list_users_by_account_info( + "t2", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="facebook", third_party_user_id="fbid1" + ) + ), + ) + g_user_by_tpid3a = await list_users_by_account_info( + "t3", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="google", third_party_user_id="googleid1" + ) + ), + ) + g_user_by_tpid3b = await list_users_by_account_info( + "t3", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="facebook", third_party_user_id="fbid1" + ) + ), + ) assert g_user_by_tpid1a == user1a.user assert g_user_by_tpid1b == user1b.user diff --git a/tests/utils.py b/tests/utils.py index cd1b2b2fa..bc4386cfa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -265,12 +265,12 @@ def extract_info(response: Response) -> Dict[str, Any]: "antiCsrf": response.headers.get("anti-csrf"), "accessTokenFromHeader": access_token_from_header, "refreshTokenFromHeader": refresh_token_from_header, - "accessTokenFromAny": access_token_from_header - if access_token is None - else access_token, - "refreshTokenFromAny": refresh_token_from_header - if refresh_token is None - else refresh_token, + "accessTokenFromAny": ( + access_token_from_header if access_token is None else access_token + ), + "refreshTokenFromAny": ( + refresh_token_from_header if refresh_token is None else refresh_token + ), } @@ -599,5 +599,5 @@ async def create_users( ) elif user["recipe"] == "thirdparty" and thirdparty: await manually_create_or_update_user( - "public", user["provider"], user["userId"], user["email"] + "public", user["provider"], user["userId"], user["email"], True, None ) From 7af26f72ac1ffe89d85f37eff7578f57e6d2c998 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 13 Sep 2024 14:06:10 +0530 Subject: [PATCH 031/126] third party provider updates --- .../thirdparty/providers/active_directory.py | 13 +++--- .../recipe/thirdparty/providers/apple.py | 46 ++++++++++++++++++- .../recipe/thirdparty/providers/facebook.py | 45 +++++++++++++++++- .../recipe/thirdparty/providers/gitlab.py | 7 +-- .../recipe/thirdparty/providers/okta.py | 22 ++++----- 5 files changed, 107 insertions(+), 26 deletions(-) diff --git a/supertokens_python/recipe/thirdparty/providers/active_directory.py b/supertokens_python/recipe/thirdparty/providers/active_directory.py index 14e20e827..2a4476a90 100644 --- a/supertokens_python/recipe/thirdparty/providers/active_directory.py +++ b/supertokens_python/recipe/thirdparty/providers/active_directory.py @@ -37,13 +37,12 @@ async def get_config_for_client_type( raise Exception( "Please provide the directoryId in the additionalConfig of the Active Directory provider." ) - else: - config.oidc_discovery_endpoint = f"https://login.microsoftonline.com/{config.additional_config['directoryId']}/v2.0/.well-known/openid-configuration" - - # The config could be coming from core where we didn't add the well-known previously - config.oidc_discovery_endpoint = normalise_oidc_endpoint_to_include_well_known( - config.oidc_discovery_endpoint - ) + if config.oidc_discovery_endpoint is not None: + config.oidc_discovery_endpoint = ( + normalise_oidc_endpoint_to_include_well_known( + config.oidc_discovery_endpoint + ) + ) if config.scope is None: config.scope = ["openid", "email"] diff --git a/supertokens_python/recipe/thirdparty/providers/apple.py b/supertokens_python/recipe/thirdparty/providers/apple.py index 93aa7ebac..3bba95ef8 100644 --- a/supertokens_python/recipe/thirdparty/providers/apple.py +++ b/supertokens_python/recipe/thirdparty/providers/apple.py @@ -12,14 +12,17 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +import json from re import sub from typing import Any, Dict, Optional from jwt import encode # type: ignore from time import time +from supertokens_python.recipe.thirdparty.types import UserInfo + from .custom import GenericProvider, NewProvider -from ..provider import Provider, ProviderConfigForClient, ProviderInput +from ..provider import Provider, ProviderConfigForClient, ProviderInput, RedirectUriInfo from .utils import ( get_actual_client_id_from_development_client_id, normalise_oidc_endpoint_to_include_well_known, @@ -76,6 +79,47 @@ async def _get_client_secret( # pylint: disable=no-self-use headers=headers, ) # type: ignore + async def exchange_auth_code_for_oauth_tokens( + self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any] + ) -> Dict[str, Any]: + response = await super().exchange_auth_code_for_oauth_tokens( + redirect_uri_info, user_context + ) + + user = redirect_uri_info.redirect_uri_query_params.get("user") + if user is not None: + if isinstance(user, str): + response["user"] = json.loads(user) + elif isinstance(user, dict): + response["user"] = user + + return response + + async def get_user_info( + self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] + ) -> UserInfo: + response = await super().get_user_info(oauth_tokens, user_context) + user = oauth_tokens.get("user") + + user_dict: Dict[str, Any] = {} + + if user is not None: + if isinstance(user, str): + user_dict = json.loads(user) + elif isinstance(user, dict): + user_dict = user + else: + return response + + if response.raw_user_info_from_provider.from_id_token_payload is None: + response.raw_user_info_from_provider.from_id_token_payload = {} + + response.raw_user_info_from_provider.from_id_token_payload[ + "user" + ] = user_dict + + return response + def Apple(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin if not input.config.name: diff --git a/supertokens_python/recipe/thirdparty/providers/facebook.py b/supertokens_python/recipe/thirdparty/providers/facebook.py index d6ac1a890..327be5181 100644 --- a/supertokens_python/recipe/thirdparty/providers/facebook.py +++ b/supertokens_python/recipe/thirdparty/providers/facebook.py @@ -44,12 +44,55 @@ async def get_config_for_client_type( async def get_user_info( self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] ) -> UserInfo: + fields_permission_map = { + "public_profile": [ + "first_name", + "last_name", + "middle_name", + "name", + "name_format", + "picture", + "short_name", + ], + "email": ["id", "email"], + "user_birthday": ["birthday"], + "user_videos": ["videos"], + "user_posts": ["posts"], + "user_photos": ["photos"], + "user_location": ["location"], + "user_link": ["link"], + "user_likes": ["likes"], + "user_hometown": ["hometown"], + "user_gender": ["gender"], + "user_friends": ["friends"], + "user_age_range": ["age_range"], + } + scope_values = self.config.scope + + fields = ( + ",".join( + [ + field + for scope in scope_values + for field in fields_permission_map.get(scope, []) + ] + ) + if scope_values + else "id,email" + ) + self.config.user_info_endpoint_query_params = { "access_token": str(oauth_tokens["access_token"]), - "fields": "id,email", + "fields": fields, "format": "json", **(self.config.user_info_endpoint_query_params or {}), } + + self.config.user_info_endpoint_headers = { + **(self.config.user_info_endpoint_headers or {}), + "Authorization": None, + } + return await super().get_user_info(oauth_tokens, user_context) diff --git a/supertokens_python/recipe/thirdparty/providers/gitlab.py b/supertokens_python/recipe/thirdparty/providers/gitlab.py index de768067c..bdbcac750 100644 --- a/supertokens_python/recipe/thirdparty/providers/gitlab.py +++ b/supertokens_python/recipe/thirdparty/providers/gitlab.py @@ -47,9 +47,10 @@ async def get_config_for_client_type( oidc_domain.get_as_string_dangerous() + oidc_path.get_as_string_dangerous() ) - - if not config.oidc_discovery_endpoint: - raise Exception("should never come here") + elif config.oidc_discovery_endpoint is None: + config.oidc_discovery_endpoint = ( + "https://gitlab.com/.well-known/openid-configuration" + ) # The config could be coming from core where we didn't add the well-known previously config.oidc_discovery_endpoint = normalise_oidc_endpoint_to_include_well_known( diff --git a/supertokens_python/recipe/thirdparty/providers/okta.py b/supertokens_python/recipe/thirdparty/providers/okta.py index 9950afe39..56e180029 100644 --- a/supertokens_python/recipe/thirdparty/providers/okta.py +++ b/supertokens_python/recipe/thirdparty/providers/okta.py @@ -32,14 +32,9 @@ async def get_config_for_client_type( config = await super().get_config_for_client_type(client_type, user_context) if ( - config.additional_config is None - or config.additional_config.get("oktaDomain") is None + config.additional_config is not None + and config.additional_config.get("oktaDomain") is not None ): - if not config.oidc_discovery_endpoint: - raise Exception( - "Please provide the oktaDomain in the additionalConfig of the Okta provider." - ) - else: okta_domain = config.additional_config["oktaDomain"] oidc_domain = NormalisedURLDomain(okta_domain) oidc_path = NormalisedURLPath("/.well-known/openid-configuration") @@ -48,13 +43,12 @@ async def get_config_for_client_type( + oidc_path.get_as_string_dangerous() ) - if not config.oidc_discovery_endpoint: - raise Exception("should never happen") - - # The config could be coming from core where we didn't add the well-known previously - config.oidc_discovery_endpoint = normalise_oidc_endpoint_to_include_well_known( - config.oidc_discovery_endpoint - ) + if config.oidc_discovery_endpoint is not None: + config.oidc_discovery_endpoint = ( + normalise_oidc_endpoint_to_include_well_known( + config.oidc_discovery_endpoint + ) + ) if config.scope is None: config.scope = ["openid", "email"] From c69dfeadad419d0114295146a9fb71277cbb1124 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 16 Sep 2024 12:32:40 +0530 Subject: [PATCH 032/126] adds 2 dashboard apis --- .../create_or_update_third_party_config.py | 161 ++++++++++++++++++ .../api/multitenancy/create_tenant.py | 97 +++++++++++ .../recipe/dashboard/constants.py | 2 + supertokens_python/recipe/dashboard/recipe.py | 28 +++ .../recipe/multitenancy/asyncio/__init__.py | 3 +- .../recipe/multitenancy/interfaces.py | 36 +++- .../multitenancy/recipe_implementation.py | 3 +- .../recipe/multitenancy/syncio/__init__.py | 5 +- .../recipe/thirdparty/provider.py | 60 ++++++- .../recipe/thirdparty/providers/utils.py | 7 +- supertokens_python/utils.py | 7 + tests/emailpassword/test_multitenancy.py | 16 +- tests/multitenancy/test_tenants_crud.py | 36 ++-- tests/passwordless/test_mutlitenancy.py | 10 +- tests/test-server/multitenancy.py | 5 +- tests/thirdparty/test_multitenancy.py | 28 ++- tests/userroles/test_multitenancy.py | 16 +- 17 files changed, 478 insertions(+), 42 deletions(-) create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py new file mode 100644 index 000000000..32f403a82 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.multitenancy.asyncio import ( + get_tenant, + create_or_update_third_party_config, +) +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from supertokens_python.normalised_url_domain import NormalisedURLDomain +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.recipe.thirdparty.providers.utils import do_post_request +from supertokens_python.types import APIResponse +from supertokens_python.utils import encode_base64 +from ...interfaces import APIInterface, APIOptions +import asyncio +import json + + +class CreateOrUpdateThirdPartyConfigOkResult(APIResponse): + def __init__(self, created_new: bool): + self.status = "OK" + self.created_new = created_new + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "createdNew": self.created_new} + + +class CreateOrUpdateThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class CreateOrUpdateThirdPartyConfigBoxyError(APIResponse): + def __init__(self, message: str): + self.status = "BOXY_ERROR" + self.message = message + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "message": self.message} + + +async def handle_create_or_update_third_party_config( + _: APIInterface, + tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + CreateOrUpdateThirdPartyConfigOkResult, + CreateOrUpdateThirdPartyConfigUnknownTenantError, + CreateOrUpdateThirdPartyConfigBoxyError, +]: + request_body = await api_options.request.json() + if request_body is None: + raise BadInputError("Request body is required") + provider_config = request_body.get("providerConfig") + + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return CreateOrUpdateThirdPartyConfigUnknownTenantError() + + if len(tenant_res.third_party_providers) == 0: + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers or [] + for provider in static_providers: + if ( + provider.include_in_non_public_tenants_by_default + or tenant_id == DEFAULT_TENANT_ID + ): + await create_or_update_third_party_config( + tenant_id, + ProviderConfig(third_party_id=provider.config.third_party_id), + None, + user_context, + ) + await asyncio.sleep(0.5) # 500ms delay + + if provider_config["thirdPartyId"].startswith("boxy-saml"): + boxy_url = provider_config["clients"][0]["additionalConfig"]["boxyURL"] + boxy_api_key = provider_config["clients"][0]["additionalConfig"]["boxyAPIKey"] + provider_config["clients"][0]["additionalConfig"]["boxyAPIKey"] = None + + if boxy_api_key and provider_config["clients"][0]["additionalConfig"][ + "samlInputType" + ] in [ + "xml", + "url", + ]: + request_body_input: Dict[str, Any] = { + "name": "", + "label": "", + "description": "", + "tenant": provider_config["clients"][0] + .get("additionalConfig", {}) + .get("boxyTenant") + or f"{tenant_id}-{provider_config['thirdPartyId']}", + "product": provider_config["clients"][0]["additionalConfig"].get( + "boxyProduct" + ) + or "supertokens", + "defaultRedirectUrl": provider_config["clients"][0]["additionalConfig"][ + "redirectURLs" + ][0], + "forceAuthn": False, + "encodedRawMetadata": encode_base64( + provider_config["clients"][0]["additionalConfig"].get("samlXML", "") + ), + "redirectUrl": json.dumps( + provider_config["clients"][0]["additionalConfig"]["redirectURLs"] + ), + "metadataUrl": provider_config["clients"][0]["additionalConfig"].get( + "samlURL", "" + ), + } + + normalised_domain = NormalisedURLDomain(boxy_url) + normalised_base_path = NormalisedURLPath(boxy_url) + connections_path = NormalisedURLPath("/api/v1/saml/config") + + status, resp = await do_post_request( + normalised_domain.get_as_string_dangerous() + + normalised_base_path.append( + connections_path + ).get_as_string_dangerous(), + body_params=request_body_input, + headers={"Authorization": f"Api-Key {boxy_api_key}"}, + ) + + if status != 200: + if status == 401: + return CreateOrUpdateThirdPartyConfigBoxyError("Invalid API Key") + return CreateOrUpdateThirdPartyConfigBoxyError( + resp.get("message", "Unknown error") + ) + + provider_config["clients"][0]["clientId"] = resp["clientID"] + provider_config["clients"][0]["clientSecret"] = resp["clientSecret"] + + third_party_res = await create_or_update_third_party_config( + tenant_id, provider_config, None, user_context + ) + + return CreateOrUpdateThirdPartyConfigOkResult(third_party_res.created_new) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py b/supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py new file mode 100644 index 000000000..e84c5c873 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.multitenancy.asyncio import create_or_update_tenant +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions + + +class CreateTenantOkResult(APIResponse): + def __init__(self, created_new: bool): + self.status = "OK" + self.created_new = created_new + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "createdNew": self.created_new} + + +class CreateTenantMultitenancyNotEnabledError(APIResponse): + def __init__(self): + self.status = "MULTITENANCY_NOT_ENABLED_IN_CORE_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class CreateTenantTenantIdAlreadyExistsError(APIResponse): + def __init__(self): + self.status = "TENANT_ID_ALREADY_EXISTS_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class CreateTenantInvalidTenantIdError(APIResponse): + def __init__(self, message: str): + self.status = "INVALID_TENANT_ID_ERROR" + self.message = message + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "message": self.message} + + +async def create_tenant( + _: APIInterface, + __: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + CreateTenantOkResult, + CreateTenantMultitenancyNotEnabledError, + CreateTenantTenantIdAlreadyExistsError, + CreateTenantInvalidTenantIdError, +]: + request_body = await options.request.json() + if request_body is None: + raise BadInputError("Request body is required") + tenant_id = request_body.get("tenantId") + config = {k: v for k, v in request_body.items() if k != "tenantId"} + + if not isinstance(tenant_id, str) or tenant_id == "": + raise BadInputError("Missing required parameter 'tenantId'") + + try: + tenant_res = await create_or_update_tenant( + tenant_id, TenantConfigCreateOrUpdate.from_json(config), user_context + ) + except Exception as err: + err_msg: str = str(err) + if "SuperTokens core threw an error for a " in err_msg: + if "with status code: 402" in err_msg: + return CreateTenantMultitenancyNotEnabledError() + if "with status code: 400" in err_msg: + return CreateTenantInvalidTenantIdError( + err_msg.split(" and message: ")[1] + ) + raise err + + if not tenant_res.created_new: + return CreateTenantTenantIdAlreadyExistsError() + + return CreateTenantOkResult(tenant_res.created_new) diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 258718254..05aed583e 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -13,3 +13,5 @@ SEARCH_TAGS_API = "/api/search/tags" DASHBOARD_ANALYTICS_API = "/api/analytics" TENANTS_LIST_API = "/api/tenants/list" +TENANT_THIRD_PARTY_CONFIG_API = "/api/thirdparty/config" +TENANT_API = "/api/tenant" diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 5a5bd7b8f..6cb558d94 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -17,6 +17,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Dict, Any from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.dashboard.api.multitenancy.create_or_update_third_party_config import ( + handle_create_or_update_third_party_config, +) +from supertokens_python.recipe.dashboard.api.multitenancy.create_tenant import ( + create_tenant, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -72,6 +78,8 @@ USERS_LIST_GET_API, VALIDATE_KEY_API, TENANTS_LIST_API, + TENANT_THIRD_PARTY_CONFIG_API, + TENANT_API, ) from .utils import ( InputOverrideConfig, @@ -252,6 +260,20 @@ def get_apis_handled(self) -> List[APIHandled]: TENANTS_LIST_API, False, ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), + "post", + TENANT_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) + ), + "put", + TENANT_THIRD_PARTY_CONFIG_API, + False, + ), ] async def handle_api_request( @@ -331,6 +353,12 @@ async def handle_api_request( api_function = handle_analytics_post elif request_id == TENANTS_LIST_API: api_function = handle_list_tenants_api + elif request_id == TENANT_API: + if method == "post": + api_function = create_tenant + elif request_id == TENANT_THIRD_PARTY_CONFIG_API: + if method == "put": + api_function = handle_create_or_update_third_party_config if api_function is not None: return await api_key_protector( diff --git a/supertokens_python/recipe/multitenancy/asyncio/__init__.py b/supertokens_python/recipe/multitenancy/asyncio/__init__.py index 53051de60..4f8020401 100644 --- a/supertokens_python/recipe/multitenancy/asyncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/asyncio/__init__.py @@ -29,6 +29,7 @@ AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, DisassociateUserFromTenantOkResult, + TenantConfigCreateOrUpdate, ) from ..recipe import MultitenancyRecipe @@ -38,7 +39,7 @@ async def create_or_update_tenant( tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Optional[Dict[str, Any]] = None, ) -> CreateOrUpdateTenantOkResult: if user_context is None: diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 4789a2c9e..4fc5ef551 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -43,6 +43,40 @@ def __init__( self.required_secondary_factors = required_secondary_factors self.third_party_providers = third_party_providers + @staticmethod + def from_json(json: Dict[str, Any]) -> TenantConfig: + return TenantConfig( + tenant_id=json.get("tenantId", ""), + third_party_providers=[ + ProviderConfig.from_json(provider) + for provider in json.get("thirdPartyProviders", []) + ], + core_config=json.get("coreConfig", {}), + first_factors=json.get("firstFactors", []), + required_secondary_factors=json.get("requiredSecondaryFactors", []), + ) + + +class TenantConfigCreateOrUpdate: + # pylint: disable=dangerous-default-value + def __init__( + self, + core_config: Dict[str, Any] = {}, + first_factors: Optional[List[str]] = None, + required_secondary_factors: Optional[List[str]] = None, + ): + self.core_config = core_config + self.first_factors = first_factors + self.required_secondary_factors = required_secondary_factors + + @staticmethod + def from_json(json: Dict[str, Any]) -> TenantConfigCreateOrUpdate: + return TenantConfigCreateOrUpdate( + core_config=json.get("coreConfig", {}), + first_factors=json.get("firstFactors", []), + required_secondary_factors=json.get("requiredSecondaryFactors", []), + ) + class CreateOrUpdateTenantOkResult: status = "OK" @@ -123,7 +157,7 @@ async def get_tenant_id( async def create_or_update_tenant( self, tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Dict[str, Any], ) -> CreateOrUpdateTenantOkResult: pass diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 8313c6491..432ffc07d 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -32,6 +32,7 @@ ListAllTenantsOkResult, CreateOrUpdateThirdPartyConfigOkResult, DeleteThirdPartyConfigOkResult, + TenantConfigCreateOrUpdate, ) if TYPE_CHECKING: @@ -127,7 +128,7 @@ async def get_tenant_id( async def create_or_update_tenant( self, tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Dict[str, Any], ) -> CreateOrUpdateTenantOkResult: response = await self.querier.send_put_request( diff --git a/supertokens_python/recipe/multitenancy/syncio/__init__.py b/supertokens_python/recipe/multitenancy/syncio/__init__.py index e152c880d..7384ee6bd 100644 --- a/supertokens_python/recipe/multitenancy/syncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/syncio/__init__.py @@ -15,15 +15,16 @@ from typing import Any, Dict, Optional, TYPE_CHECKING from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate from supertokens_python.types import RecipeUserId if TYPE_CHECKING: - from ..interfaces import TenantConfig, ProviderConfig + from ..interfaces import ProviderConfig def create_or_update_tenant( tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Optional[Dict[str, Any]] = None, ): if user_context is None: diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index 568242ca4..769343a33 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -111,6 +111,17 @@ def to_json(self) -> Dict[str, Any]: return {k: v for k, v in res.items() if v is not None} + @staticmethod + def from_json(json: Dict[str, Any]) -> ProviderClientConfig: + return ProviderClientConfig( + client_id=json.get("clientId", ""), + client_secret=json.get("clientSecret", None), + client_type=json.get("clientType", None), + scope=json.get("scope", None), + force_pkce=json.get("forcePKCE", None), + additional_config=json.get("additionalConfig", None), + ) + class UserFields: def __init__( @@ -132,6 +143,14 @@ def to_json(self) -> Dict[str, Any]: return {k: v for k, v in res.items() if v is not None} + @staticmethod + def from_json(json: Dict[str, Any]) -> UserFields: + return UserFields( + user_id=json.get("userId", None), + email=json.get("email", None), + email_verified=json.get("emailVerified", None), + ) + class UserInfoMap: def __init__( @@ -150,6 +169,15 @@ def to_json(self) -> Dict[str, Any]: res["fromUserInfoAPI"] = self.from_user_info_api.to_json() return res + @staticmethod + def from_json(json: Dict[str, Any]) -> UserInfoMap: + return UserInfoMap( + from_id_token_payload=UserFields.from_json( + json.get("fromIdTokenPayload", {}) + ), + from_user_info_api=UserFields.from_json(json.get("fromUserInfoAPI", {})), + ) + class CommonProviderConfig: def __init__( @@ -213,9 +241,9 @@ def to_json(self) -> Dict[str, Any]: "userInfoEndpointHeaders": self.user_info_endpoint_headers, "jwksURI": self.jwks_uri, "oidcDiscoveryEndpoint": self.oidc_discovery_endpoint, - "userInfoMap": self.user_info_map.to_json() - if self.user_info_map is not None - else None, + "userInfoMap": ( + self.user_info_map.to_json() if self.user_info_map is not None else None + ), "requireEmail": self.require_email, } @@ -362,6 +390,32 @@ def to_json(self) -> Dict[str, Any]: return d + @staticmethod + def from_json(json: Dict[str, Any]) -> ProviderConfig: + return ProviderConfig( + third_party_id=json.get("thirdPartyId", ""), + name=json.get("name", ""), + clients=[ + ProviderClientConfig.from_json(c) for c in json.get("clients", []) + ], + authorization_endpoint=json.get("authorizationEndpoint", ""), + authorization_endpoint_query_params=json.get( + "authorizationEndpointQueryParams", {} + ), + token_endpoint=json.get("tokenEndpoint", ""), + token_endpoint_body_params=json.get("tokenEndpointBodyParams", {}), + user_info_endpoint=json.get("userInfoEndpoint", ""), + user_info_endpoint_query_params=json.get("userInfoEndpointQueryParams", {}), + user_info_endpoint_headers=json.get("userInfoEndpointHeaders", {}), + jwks_uri=json.get("jwksURI", ""), + oidc_discovery_endpoint=json.get("oidcDiscoveryEndpoint", ""), + user_info_map=UserInfoMap.from_json(json.get("userInfoMap", {})), + require_email=json.get("requireEmail", None), + validate_id_token_payload=None, + generate_fake_email=None, + validate_access_token=None, + ) + class ProviderInput: def __init__( diff --git a/supertokens_python/recipe/thirdparty/providers/utils.py b/supertokens_python/recipe/thirdparty/providers/utils.py index 9f9aaf80b..1f274461a 100644 --- a/supertokens_python/recipe/thirdparty/providers/utils.py +++ b/supertokens_python/recipe/thirdparty/providers/utils.py @@ -48,7 +48,7 @@ async def do_get_request( async def do_post_request( url: str, - body_params: Optional[Dict[str, str]] = None, + body_params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> Tuple[int, Dict[str, Any]]: if body_params is None: @@ -64,7 +64,10 @@ async def do_post_request( log_debug_message( "Received response with status %s and body %s", res.status_code, res.text ) - return res.status_code, res.json() + try: + return res.status_code, res.json() + except Exception: + return res.status_code, {"message": res.text} def normalise_oidc_endpoint_to_include_well_known(url: str) -> str: diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index 7e3c43e44..26fc09be6 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -183,6 +183,13 @@ def utf_base64decode(s: str, urlsafe: bool) -> str: return b64decode(s.encode("utf-8")).decode("utf-8") +def encode_base64(value: str) -> str: + """ + Encode the passed value to base64 and return the encoded value. + """ + return b64encode(value.encode()).decode() + + def get_filtered_list(func: Callable[[_T], bool], given_list: List[_T]) -> List[_T]: return list(filter(func, given_list)) diff --git a/tests/emailpassword/test_multitenancy.py b/tests/emailpassword/test_multitenancy.py index 12d461b0a..9b7e521e0 100644 --- a/tests/emailpassword/test_multitenancy.py +++ b/tests/emailpassword/test_multitenancy.py @@ -29,7 +29,9 @@ SignInOkResult, CreateResetPasswordOkResult, ) -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args @@ -62,9 +64,15 @@ async def test_multitenancy_in_emailpassword(): setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) - await create_or_update_tenant("t2", TenantConfig(first_factors=["emailpassword"])) - await create_or_update_tenant("t3", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) user1 = await sign_up("t1", "test@example.com", "password1") user2 = await sign_up("t2", "test@example.com", "password2") diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 53d75e7f9..7127e0735 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -46,7 +46,9 @@ ) from supertokens_python.recipe.emailpassword.asyncio import sign_up from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.thirdparty.provider import ( ProviderConfig, ProviderClientConfig, @@ -69,12 +71,18 @@ async def test_tenant_crud(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) await create_or_update_tenant( "t2", - TenantConfig(first_factors=["otp-email, otp-phone, link-email, link-phone"]), + TenantConfigCreateOrUpdate( + first_factors=["otp-email, otp-phone, link-email, link-phone"] + ), + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) ) - await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) tenants = await list_all_tenants() assert len(tenants.tenants) == 3 @@ -115,7 +123,7 @@ async def test_tenant_crud(): # update tenant1 to add passwordless: await create_or_update_tenant( "t1", - TenantConfig( + TenantConfigCreateOrUpdate( first_factors=[ "otp-email", "otp-phone", @@ -136,7 +144,9 @@ async def test_tenant_crud(): assert t1_config.core_config == {} # update tenant1 to add thirdparty: - await create_or_update_tenant("t1", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) t1_config = await get_tenant("t1") assert t1_config is not None assert t1_config.first_factors is not None @@ -161,7 +171,9 @@ async def test_tenant_thirdparty_config(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) await create_or_update_third_party_config( "t1", config=ProviderConfig( @@ -288,14 +300,18 @@ async def test_user_association_and_disassociation_with_tenants(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) await create_or_update_tenant( "t2", - TenantConfig( + TenantConfigCreateOrUpdate( first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] ), ) - await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) signup_response = await sign_up("public", "test@example.com", "password1") assert isinstance(signup_response, SignUpOkResult) diff --git a/tests/passwordless/test_mutlitenancy.py b/tests/passwordless/test_mutlitenancy.py index a6e1fed35..c1ff18bdf 100644 --- a/tests/passwordless/test_mutlitenancy.py +++ b/tests/passwordless/test_mutlitenancy.py @@ -23,7 +23,9 @@ consume_code, ConsumeCodeOkResult, ) -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args @@ -59,19 +61,19 @@ async def test_multitenancy_functions(): await create_or_update_tenant( "t1", - TenantConfig( + TenantConfigCreateOrUpdate( first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] ), ) await create_or_update_tenant( "t2", - TenantConfig( + TenantConfigCreateOrUpdate( first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] ), ) await create_or_update_tenant( "t3", - TenantConfig( + TenantConfigCreateOrUpdate( first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] ), ) diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index 03bdd342f..ff2aac785 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -1,7 +1,7 @@ from flask import Flask, request, jsonify from supertokens_python.recipe.multitenancy.interfaces import ( AssociateUserToTenantOkResult, - TenantConfig, + TenantConfigCreateOrUpdate, ) import supertokens_python.recipe.multitenancy.syncio as multitenancy from supertokens_python.recipe.thirdparty import ( @@ -22,10 +22,9 @@ def create_or_update_tenant(): # type: ignore config = data["config"] user_context = data.get("userContext") - config = TenantConfig( + config = TenantConfigCreateOrUpdate( first_factors=config.get("firstFactors"), required_secondary_factors=config.get("requiredSecondaryFactors"), - third_party_providers=config.get("thirdPartyProviders"), core_config=config.get("coreConfig"), ) diff --git a/tests/thirdparty/test_multitenancy.py b/tests/thirdparty/test_multitenancy.py index 113ba9b38..4fbd61aa6 100644 --- a/tests/thirdparty/test_multitenancy.py +++ b/tests/thirdparty/test_multitenancy.py @@ -23,7 +23,9 @@ manually_create_or_update_user, get_provider, ) -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.thirdparty.interfaces import ( ManuallyCreateOrUpdateUserOkResult, ) @@ -53,9 +55,15 @@ async def test_thirtyparty_multitenancy_functions(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["thirdparty"])) - await create_or_update_tenant("t2", TenantConfig(first_factors=["thirdparty"])) - await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) # sign up: user1a = await manually_create_or_update_user( @@ -215,9 +223,15 @@ async def test_get_provider(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["thirdparty"])) - await create_or_update_tenant("t2", TenantConfig(first_factors=["thirdparty"])) - await create_or_update_tenant("t3", TenantConfig(first_factors=["thirdparty"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) await create_or_update_third_party_config( "t1", diff --git a/tests/userroles/test_multitenancy.py b/tests/userroles/test_multitenancy.py index b67578172..51cb5d133 100644 --- a/tests/userroles/test_multitenancy.py +++ b/tests/userroles/test_multitenancy.py @@ -20,7 +20,9 @@ ) from supertokens_python.recipe.emailpassword.asyncio import sign_up from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.userroles.asyncio import ( create_new_role_or_add_permissions, add_role_to_user, @@ -57,9 +59,15 @@ async def test_multitenancy_in_user_roles(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(first_factors=["emailpassword"])) - await create_or_update_tenant("t2", TenantConfig(first_factors=["emailpassword"])) - await create_or_update_tenant("t3", TenantConfig(first_factors=["emailpassword"])) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) user = await sign_up("public", "test@example.com", "password1") assert isinstance(user, SignUpOkResult) From d123e191112b7e5c96e91c4050279cc7fb7761c3 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 16 Sep 2024 13:00:16 +0530 Subject: [PATCH 033/126] adds one more dahsboard api --- .../api/multitenancy/delete_tenant.py | 57 +++++++++++++++++++ supertokens_python/recipe/dashboard/recipe.py | 11 ++++ 2 files changed, 68 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py b/supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py new file mode 100644 index 000000000..f14bf1e1f --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from typing_extensions import Literal +from supertokens_python.recipe.multitenancy.asyncio import delete_tenant +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions + + +class DeleteTenantOkResult(APIResponse): + def __init__(self, did_exist: bool): + self.status: Literal["OK"] = "OK" + self.did_exist = did_exist + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "didExist": self.did_exist} + + +class DeleteTenantCannotDeletePublicTenantError(APIResponse): + def __init__(self): + self.status: Literal[ + "CANNOT_DELETE_PUBLIC_TENANT_ERROR" + ] = "CANNOT_DELETE_PUBLIC_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def delete_tenant_api( + _: APIInterface, + tenant_id: str, + __: APIOptions, + user_context: Dict[str, Any], +) -> Union[DeleteTenantOkResult, DeleteTenantCannotDeletePublicTenantError]: + try: + delete_tenant_res = await delete_tenant(tenant_id, user_context) + return DeleteTenantOkResult(delete_tenant_res.did_exist) + except Exception as err: + err_msg: str = str(err) + if ( + "SuperTokens core threw an error for a " in err_msg + and "with status code: 403" in err_msg + ): + return DeleteTenantCannotDeletePublicTenantError() + raise err diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 6cb558d94..5e23c8ffd 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -23,6 +23,9 @@ from supertokens_python.recipe.dashboard.api.multitenancy.create_tenant import ( create_tenant, ) +from supertokens_python.recipe.dashboard.api.multitenancy.delete_tenant import ( + delete_tenant_api, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -266,6 +269,12 @@ def get_apis_handled(self) -> List[APIHandled]: TENANT_API, False, ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), + "delete", + TENANT_API, + False, + ), APIHandled( NormalisedURLPath( get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) @@ -356,6 +365,8 @@ async def handle_api_request( elif request_id == TENANT_API: if method == "post": api_function = create_tenant + if method == "delete": + api_function = delete_tenant_api elif request_id == TENANT_THIRD_PARTY_CONFIG_API: if method == "put": api_function = handle_create_or_update_third_party_config From 09f54cfdc46a7b93728d5a7de58d201f212c906a Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 16 Sep 2024 15:42:50 +0530 Subject: [PATCH 034/126] more apis --- .../multitenancy/delete_third_party_config.py | 123 +++++++++++++ .../api/multitenancy/get_tenant_info.py | 168 ++++++++++++++++++ .../dashboard/api/multitenancy/utils.py | 40 +++++ .../recipe/dashboard/interfaces.py | 58 ++++++ supertokens_python/recipe/dashboard/recipe.py | 32 ++++ 5 files changed, 421 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/utils.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py new file mode 100644 index 000000000..84688e19c --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multifactorauth.types import FactorIds +from supertokens_python.recipe.multitenancy.asyncio import ( + get_tenant, + create_or_update_third_party_config, + create_or_update_tenant, + delete_third_party_config, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +import asyncio + + +class DeleteThirdPartyConfigOkResult(APIResponse): + def __init__(self, did_config_exist: bool): + self.status: Literal["OK"] = "OK" + self.did_config_exist = did_config_exist + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "didConfigExist": self.did_config_exist} + + +class DeleteThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def delete_third_party_config_api( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[DeleteThirdPartyConfigOkResult, DeleteThirdPartyConfigUnknownTenantError]: + third_party_id = options.request.get_query_param("thirdPartyId") + + if not tenant_id or not third_party_id: + raise_bad_input_exception( + "Missing required parameter 'tenantId' or 'thirdPartyId'" + ) + + assert third_party_id is not None + + tenant_res = await get_tenant(tenant_id, user_context) + if tenant_res is None: + return DeleteThirdPartyConfigUnknownTenantError() + + third_party_ids_from_core = [ + provider.third_party_id for provider in tenant_res.third_party_providers + ] + + if len(third_party_ids_from_core) == 0: + # This means that the tenant was using the static list of providers, we need to add them all before deleting one + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = ( + mt_recipe.static_third_party_providers if mt_recipe.config else [] + ) + static_provider_ids = [ + provider.config.third_party_id for provider in static_providers + ] + + for provider_id in static_provider_ids: + await create_or_update_third_party_config( + tenant_id, + ProviderConfig(third_party_id=provider_id), + None, + user_context, + ) + # Delay after each provider to avoid rate limiting + await asyncio.sleep(0.5) # 500ms + elif ( + len(third_party_ids_from_core) == 1 + and third_party_ids_from_core[0] == third_party_id + ): + if tenant_res.first_factors is None: + # Add all static first factors except thirdparty + await create_or_update_tenant( + tenant_id, + TenantConfigCreateOrUpdate( + first_factors=[ + FactorIds.EMAILPASSWORD, + FactorIds.OTP_PHONE, + FactorIds.OTP_EMAIL, + FactorIds.LINK_PHONE, + FactorIds.LINK_EMAIL, + ] + ), + user_context, + ) + elif "thirdparty" in tenant_res.first_factors: + # Add all static first factors except thirdparty + new_first_factors = [ + factor for factor in tenant_res.first_factors if factor != "thirdparty" + ] + await create_or_update_tenant( + tenant_id, + TenantConfigCreateOrUpdate(first_factors=new_first_factors), + user_context, + ) + + result = await delete_third_party_config(tenant_id, third_party_id, user_context) + return DeleteThirdPartyConfigOkResult(result.did_config_exist) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py new file mode 100644 index 000000000..17f42b7e7 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py @@ -0,0 +1,168 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Union +from typing_extensions import Literal + +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python import Supertokens +from supertokens_python.types import APIResponse +from .utils import ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init, +) +from supertokens_python.recipe.thirdparty.providers.config_utils import ( + find_and_create_provider_instance, + merge_providers_from_core_and_static, +) +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier +from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from ...interfaces import APIInterface, APIOptions, CoreConfigFieldInfo + + +from typing import List, Optional + + +class ThirdPartyProvider: + def __init__(self, third_party_id: str, name: str): + self.third_party_id = third_party_id + self.name = name + + def to_json(self) -> Dict[str, Any]: + return {"thirdPartyId": self.third_party_id, "name": self.name} + + +class TenantInfo: + def __init__( + self, + tenant_id: str, + third_party: List[ThirdPartyProvider], + first_factors: List[str], + required_secondary_factors: Optional[List[str]], + core_config: List[CoreConfigFieldInfo], + user_count: int, + ): + self.tenant_id = tenant_id + self.third_party = third_party + self.first_factors = first_factors + self.required_secondary_factors = required_secondary_factors + self.core_config = core_config + self.user_count = user_count + + def to_json(self) -> Dict[str, Any]: + return { + "tenantId": self.tenant_id, + "thirdParty": { + "providers": [provider.to_json() for provider in self.third_party] + }, + "firstFactors": self.first_factors, + "requiredSecondaryFactors": self.required_secondary_factors, + "coreConfig": [field.to_json() for field in self.core_config], + "userCount": self.user_count, + } + + +class GetTenantInfoOkResult(APIResponse): + def __init__(self, tenant: TenantInfo): + self.status: Literal["OK"] = "OK" + self.tenant = tenant + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "tenant": self.tenant.to_json()} + + +class GetTenantInfoUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def get_tenant_info( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[GetTenantInfoOkResult, GetTenantInfoUnknownTenantError]: + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return GetTenantInfoUnknownTenantError() + + first_factors = ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_res + ) + ) + + user_count = await Supertokens.get_instance().get_user_count( + None, tenant_id, user_context + ) + + providers_from_core = tenant_res.third_party_providers + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers + + merged_providers_from_core_and_static = merge_providers_from_core_and_static( + providers_from_core, static_providers, tenant_id == DEFAULT_TENANT_ID + ) + + querier = Querier.get_instance(options.recipe_id) + core_config = await querier.send_get_request( + NormalisedURLPath(f"/{tenant_id}/recipe/dashboard/tenant/core-config"), + {}, + user_context, + ) + + providers: List[ThirdPartyProvider] = [] + for provider in merged_providers_from_core_and_static: + try: + provider_instance = await find_and_create_provider_instance( + merged_providers_from_core_and_static, + provider.config.third_party_id, + ( + provider.config.clients[0].client_type + if provider.config.clients + else None + ), + user_context, + ) + assert provider_instance is not None + if provider_instance.config.name is None: + raise Exception("Falling back to exception block") + providers.append( + ThirdPartyProvider( + provider.config.third_party_id, + provider_instance.config.name, + ), + ) + except Exception: + providers.append( + ThirdPartyProvider( + provider.config.third_party_id, provider.config.third_party_id + ) + ) + + tenant = TenantInfo( + tenant_id, + providers, + first_factors, + tenant_res.required_secondary_factors, + [CoreConfigFieldInfo.from_json(field) for field in core_config["config"]], + user_count, + ) + + return GetTenantInfoOkResult(tenant) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/utils.py b/supertokens_python/recipe/dashboard/api/multitenancy/utils.py new file mode 100644 index 000000000..48772d25d --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/utils.py @@ -0,0 +1,40 @@ +from typing import List +from supertokens_python.recipe.multifactorauth.utils import ( + is_factor_configured_for_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + + +def get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_details_from_core: TenantConfig, +) -> List[str]: + first_factors: List[str] + + mt_instance = MultitenancyRecipe.get_instance() + + if tenant_details_from_core.first_factors is not None: + first_factors = ( + tenant_details_from_core.first_factors + ) # highest priority, config from core + elif mt_instance.static_first_factors is not None: + first_factors = mt_instance.static_first_factors # next priority, static config + else: + # Fallback to all available factors (de-duplicated) + first_factors = list(set(mt_instance.all_available_first_factors)) + + # we now filter out all available first factors by checking if they are valid because + # we want to return the ones that can work. this would be based on what recipes are enabled + # on the core and also first_factors configured in the core and supertokens.init + # Also, this way, in the front end, the developer can just check for first_factors for + # enabled recipes in all cases irrespective of whether they are using MFA or not + valid_first_factors: List[str] = [] + for factor_id in first_factors: + if is_factor_configured_for_tenant( + all_available_first_factors=mt_instance.all_available_first_factors, + first_factors=first_factors, + factor_id=factor_id, + ): + valid_first_factors.append(factor_id) + + return valid_first_factors diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index 53df121c5..9c60692fa 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union +from typing_extensions import Literal from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.types import AccountLinkingUser @@ -336,3 +337,60 @@ class AnalyticsResponse(APIResponse): def to_json(self) -> Dict[str, Any]: return {"status": self.status} + + +class CoreConfigFieldInfo: + def __init__( + self, + key: str, + value_type: Literal["string", "boolean", "number"], + value: Union[str, int, float, bool, None], + description: str, + is_different_across_tenants: bool, + possible_values: Union[List[str], None] = None, + is_nullable: bool = False, + default_value: Union[str, int, float, bool, None] = None, + is_plugin_property: bool = False, + is_plugin_property_editable: bool = False, + ): + self.key = key + self.value_type = value_type + self.value = value + self.description = description + self.is_different_across_tenants = is_different_across_tenants + self.possible_values = possible_values + self.is_nullable = is_nullable + self.default_value = default_value + self.is_plugin_property = is_plugin_property + self.is_plugin_property_editable = is_plugin_property_editable + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "key": self.key, + "valueType": self.value_type, + "value": self.value, + "description": self.description, + "isDifferentAcrossTenants": self.is_different_across_tenants, + "isNullable": self.is_nullable, + "defaultValue": self.default_value, + "isPluginProperty": self.is_plugin_property, + "isPluginPropertyEditable": self.is_plugin_property_editable, + } + if self.possible_values is not None: + result["possibleValues"] = self.possible_values + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> CoreConfigFieldInfo: + return CoreConfigFieldInfo( + key=json["key"], + value_type=json["valueType"], + value=json["value"], + description=json["description"], + is_different_across_tenants=json["isDifferentAcrossTenants"], + possible_values=json["possibleValues"], + is_nullable=json["isNullable"], + default_value=json["defaultValue"], + is_plugin_property=json["isPluginProperty"], + is_plugin_property_editable=json["isPluginPropertyEditable"], + ) diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 5e23c8ffd..32040a20a 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -26,6 +26,12 @@ from supertokens_python.recipe.dashboard.api.multitenancy.delete_tenant import ( delete_tenant_api, ) +from supertokens_python.recipe.dashboard.api.multitenancy.delete_third_party_config import ( + delete_third_party_config_api, +) +from supertokens_python.recipe.dashboard.api.multitenancy.get_tenant_info import ( + get_tenant_info, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -275,6 +281,12 @@ def get_apis_handled(self) -> List[APIHandled]: TENANT_API, False, ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), + "get", + TENANT_API, + False, + ), APIHandled( NormalisedURLPath( get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) @@ -283,6 +295,22 @@ def get_apis_handled(self) -> List[APIHandled]: TENANT_THIRD_PARTY_CONFIG_API, False, ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) + ), + "delete", + TENANT_THIRD_PARTY_CONFIG_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) + ), + "get", + TENANT_THIRD_PARTY_CONFIG_API, + False, + ), ] async def handle_api_request( @@ -367,9 +395,13 @@ async def handle_api_request( api_function = create_tenant if method == "delete": api_function = delete_tenant_api + if method == "get": + api_function = get_tenant_info elif request_id == TENANT_THIRD_PARTY_CONFIG_API: if method == "put": api_function = handle_create_or_update_third_party_config + if method == "delete": + api_function = delete_third_party_config_api if api_function is not None: return await api_key_protector( From 713efb5a574525793a94f843b388a7134ce3d7cc Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 16 Sep 2024 17:58:09 +0530 Subject: [PATCH 035/126] more apis --- .../multitenancy/get_third_party_config.py | 378 ++++++++++++++++++ supertokens_python/recipe/dashboard/recipe.py | 5 + 2 files changed, 383 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py new file mode 100644 index 000000000..375e3f065 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py @@ -0,0 +1,378 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.thirdparty import ( + ProviderClientConfig, + ProviderConfig, + ProviderInput, +) +from supertokens_python.recipe.thirdparty.provider import CommonProviderConfig, Provider +from supertokens_python.recipe.thirdparty.providers.utils import do_get_request +from supertokens_python.types import APIResponse +from supertokens_python.recipe.thirdparty.providers.config_utils import ( + find_and_create_provider_instance, + merge_providers_from_core_and_static, +) +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.normalised_url_domain import NormalisedURLDomain +from ...interfaces import APIInterface, APIOptions + + +class ProviderConfigResponse(APIResponse): + def __init__( + self, + provider_config: ProviderConfig, + is_get_authorisation_redirect_url_overridden: bool, + is_exchange_auth_code_for_oauth_tokens_overridden: bool, + is_get_user_info_overridden: bool, + ): + self.provider_config = provider_config + self.is_get_authorisation_redirect_url_overridden = ( + is_get_authorisation_redirect_url_overridden + ) + self.is_exchange_auth_code_for_oauth_tokens_overridden = ( + is_exchange_auth_code_for_oauth_tokens_overridden + ) + self.is_get_user_info_overridden = is_get_user_info_overridden + + def to_json(self) -> Dict[str, Any]: + json_response = self.provider_config.to_json() + json_response[ + "isGetAuthorisationRedirectUrlOverridden" + ] = self.is_get_authorisation_redirect_url_overridden + json_response[ + "isExchangeAuthCodeForOAuthTokensOverridden" + ] = self.is_exchange_auth_code_for_oauth_tokens_overridden + json_response["isGetUserInfoOverridden"] = self.is_get_user_info_overridden + json_response["status"] = "OK" + return json_response + + +class GetThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def get_third_party_config( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ProviderConfigResponse, GetThirdPartyConfigUnknownTenantError]: + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return GetThirdPartyConfigUnknownTenantError() + + third_party_id = options.request.get_query_param("thirdPartyId") + + if third_party_id is None: + raise_bad_input_exception("Please provide thirdPartyId") + + providers_from_core = tenant_res.third_party_providers + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers or [] + + additional_config: Optional[Dict[str, Any]] = None + + providers_from_core = [ + provider + for provider in providers_from_core + if provider.third_party_id == third_party_id + ] + + if not providers_from_core: + providers_from_core.append(ProviderConfig(third_party_id=third_party_id)) + + if third_party_id in ["okta", "active-directory", "boxy-saml", "google-workspaces"]: + if third_party_id == "okta": + okta_domain = options.request.get_query_param("oktaDomain") + if okta_domain is not None: + additional_config = {"oktaDomain": okta_domain} + elif third_party_id == "active-directory": + directory_id = options.request.get_query_param("directoryId") + if directory_id is not None: + additional_config = {"directoryId": directory_id} + elif third_party_id == "boxy-saml": + boxy_url = options.request.get_query_param("boxyUrl") + boxy_api_key = options.request.get_query_param("boxyAPIKey") + if boxy_url is not None: + additional_config = {"boxyURL": boxy_url} + if boxy_api_key is not None: + additional_config["boxyAPIKey"] = boxy_api_key + elif third_party_id == "google-workspaces": + hd = options.request.get_query_param("hd") + if hd is not None: + additional_config = {"hd": hd} + + if additional_config is not None: + providers_from_core[0].oidc_discovery_endpoint = None + providers_from_core[0].authorization_endpoint = None + providers_from_core[0].token_endpoint = None + providers_from_core[0].user_info_endpoint = None + + if providers_from_core[0].clients is not None: + for existing_client in providers_from_core[0].clients: + if existing_client.additional_config is not None: + existing_client.additional_config = { + **existing_client.additional_config, + **additional_config, + } + else: + existing_client.additional_config = additional_config + else: + providers_from_core[0].clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=additional_config, + ) + ] + + static_providers = [ + provider + for provider in static_providers + if provider.config.third_party_id == third_party_id + ] + + if not static_providers and third_party_id == "apple": + static_providers.append( + ProviderInput( + config=ProviderConfig( + third_party_id="apple", + clients=[ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id" + ) + ], + ) + ) + ) + + additional_config = { + "teamId": "", + "keyId": "", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + } + + if len(static_providers) == 1 and additional_config is not None: + static_providers[0].config.oidc_discovery_endpoint = None + static_providers[0].config.authorization_endpoint = None + static_providers[0].config.token_endpoint = None + static_providers[0].config.user_info_endpoint = None + if static_providers[0].config.clients is not None: + for existing_client in static_providers[0].config.clients: + if existing_client.additional_config is not None: + existing_client.additional_config = { + **existing_client.additional_config, + **additional_config, + } + else: + existing_client.additional_config = additional_config + else: + static_providers[0].config.clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=additional_config, + ) + ] + + merged_providers_from_core_and_static = merge_providers_from_core_and_static( + providers_from_core, static_providers, True + ) + + if len(merged_providers_from_core_and_static) != 1: + raise Exception("should never come here!") + + for merged_provider in merged_providers_from_core_and_static: + if merged_provider.config.third_party_id == third_party_id: + if not merged_provider.config.clients: + merged_provider.config.clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=( + additional_config if additional_config is not None else None + ), + ) + ] + clients: List[ProviderClientConfig] = [] + common_provider_config: CommonProviderConfig = CommonProviderConfig( + third_party_id=third_party_id + ) + is_get_authorisation_redirect_url_overridden = False + is_exchange_auth_code_for_oauth_tokens_overridden = False + is_get_user_info_overridden = False + + for provider in merged_providers_from_core_and_static: + if provider.config.third_party_id == third_party_id: + found_correct_config = False + + for client in provider.config.clients or []: + try: + provider_instance = await find_and_create_provider_instance( + merged_providers_from_core_and_static, + third_party_id, + client.client_type, + user_context, + ) + assert provider_instance is not None + clients.append( + ProviderClientConfig( + client_id=provider_instance.config.client_id, + client_secret=provider_instance.config.client_secret, + scope=provider_instance.config.scope, + client_type=provider_instance.config.client_type, + additional_config=provider_instance.config.additional_config, + force_pkce=provider_instance.config.force_pkce, + ) + ) + # common_provider_config = CommonProviderConfig( + # third_party_id=provider_instance.config.third_party_id, + # name=provider_instance.config.name, + # authorization_endpoint=provider_instance.config.authorization_endpoint, + # authorization_endpoint_query_params=provider_instance.config.authorization_endpoint_query_params, + # token_endpoint=provider_instance.config.token_endpoint, + # token_endpoint_body_params=provider_instance.config.token_endpoint_body_params, + # user_info_endpoint=provider_instance.config.user_info_endpoint, + # user_info_endpoint_query_params=provider_instance.config.user_info_endpoint_query_params, + # user_info_endpoint_headers=provider_instance.config.user_info_endpoint_headers, + # jwks_uri=provider_instance.config.jwks_uri, + # oidc_discovery_endpoint=provider_instance.config.oidc_discovery_endpoint, + # user_info_map=provider_instance.config.user_info_map, + # require_email=provider_instance.config.require_email, + # validate_id_token_payload=provider_instance.config.validate_id_token_payload, + # validate_access_token=provider_instance.config.validate_access_token, + # generate_fake_email=provider_instance.config.generate_fake_email, + # ) + common_provider_config = provider_instance.config + + if provider.override is not None: + before_override = Provider( + config=provider_instance.config, + id=provider_instance.id, + ) + after_override = provider.override(before_override) + + if ( + before_override.get_authorisation_redirect_url + != after_override.get_authorisation_redirect_url + ): + is_get_authorisation_redirect_url_overridden = True + if ( + before_override.exchange_auth_code_for_oauth_tokens + != after_override.exchange_auth_code_for_oauth_tokens + ): + is_exchange_auth_code_for_oauth_tokens_overridden = True + if ( + before_override.get_user_info + != after_override.get_user_info + ): + is_get_user_info_overridden = True + + found_correct_config = True + except Exception: + clients.append(client) + + if not found_correct_config: + common_provider_config = provider.config + + break + + if additional_config and "privateKey" in additional_config: + additional_config["privateKey"] = "" + + temp_clients = [ + client + for client in clients + if client.client_id == "nonguessable-temporary-client-id" + ] + + final_clients = [ + client + for client in clients + if client.client_id != "nonguessable-temporary-client-id" + ] + if not final_clients: + final_clients = [ + ProviderClientConfig( + client_id="", + client_secret="", + additional_config=additional_config, + client_type=temp_clients[0].client_type, + force_pkce=temp_clients[0].force_pkce, + scope=temp_clients[0].scope, + ) + ] + + if third_party_id.startswith("boxy-saml"): + boxy_api_key = options.request.get_query_param("boxyAPIKey") + if boxy_api_key and final_clients[0].client_id: + assert isinstance(final_clients[0].additional_config, dict) + boxy_url = final_clients[0].additional_config["boxyURL"] + normalised_domain = NormalisedURLDomain(boxy_url) + normalised_base_path = NormalisedURLPath(boxy_url) + connections_path = NormalisedURLPath("/api/v1/saml/config") + + resp = await do_get_request( + normalised_domain.get_as_string_dangerous() + + normalised_base_path.append( + connections_path + ).get_as_string_dangerous(), + {"clientID": final_clients[0].client_id}, + {"Authorization": f"Api-Key {boxy_api_key}"}, + ) + + json_response = resp + final_clients[0].additional_config.update( + { + "redirectURLs": json_response["redirectUrl"], + "boxyTenant": json_response["tenant"], + "boxyProduct": json_response["product"], + } + ) + + provider_config = ProviderConfig( + third_party_id=third_party_id, + clients=final_clients, + authorization_endpoint=common_provider_config.authorization_endpoint, + authorization_endpoint_query_params=common_provider_config.authorization_endpoint_query_params, + token_endpoint=common_provider_config.token_endpoint, + token_endpoint_body_params=common_provider_config.token_endpoint_body_params, + user_info_endpoint=common_provider_config.user_info_endpoint, + user_info_endpoint_query_params=common_provider_config.user_info_endpoint_query_params, + user_info_endpoint_headers=common_provider_config.user_info_endpoint_headers, + jwks_uri=common_provider_config.jwks_uri, + oidc_discovery_endpoint=common_provider_config.oidc_discovery_endpoint, + user_info_map=common_provider_config.user_info_map, + require_email=common_provider_config.require_email, + validate_id_token_payload=common_provider_config.validate_id_token_payload, + validate_access_token=common_provider_config.validate_access_token, + generate_fake_email=common_provider_config.generate_fake_email, + name=common_provider_config.name, + ) + + return ProviderConfigResponse( + provider_config=provider_config, + is_get_authorisation_redirect_url_overridden=is_get_authorisation_redirect_url_overridden, + is_exchange_auth_code_for_oauth_tokens_overridden=is_exchange_auth_code_for_oauth_tokens_overridden, + is_get_user_info_overridden=is_get_user_info_overridden, + ) diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 32040a20a..fc7761f9f 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -32,6 +32,9 @@ from supertokens_python.recipe.dashboard.api.multitenancy.get_tenant_info import ( get_tenant_info, ) +from supertokens_python.recipe.dashboard.api.multitenancy.get_third_party_config import ( + get_third_party_config, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -402,6 +405,8 @@ async def handle_api_request( api_function = handle_create_or_update_third_party_config if method == "delete": api_function = delete_third_party_config_api + if method == "get": + api_function = get_third_party_config if api_function is not None: return await api_key_protector( From da0348809d447b83350b165bf7f73221e05fba2b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 16 Sep 2024 20:05:19 +0530 Subject: [PATCH 036/126] more apis --- .../list_all_tenants_with_login_methods.py | 72 +++++++++ .../multitenancy/update_tenant_core_config.py | 98 ++++++++++++ .../update_tenant_first_factor.py | 116 ++++++++++++++ .../update_tenant_secondary_factor.py | 147 ++++++++++++++++++ .../dashboard/api/multitenancy/utils.py | 62 ++++++++ .../recipe/dashboard/constants.py | 4 + supertokens_python/recipe/dashboard/recipe.py | 58 +++++++ 7 files changed, 557 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py create mode 100644 supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py b/supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py new file mode 100644 index 000000000..95300b99a --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Any +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +from .utils import ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init, +) + + +class TenantWithLoginMethods: + def __init__(self, tenant_id: str, first_factors: List[str]): + self.tenant_id = tenant_id + self.first_factors = first_factors + + +class ListAllTenantsWithLoginMethodsOkResult(APIResponse): + def __init__(self, tenants: List[TenantWithLoginMethods]): + self.status = "OK" + self.tenants = tenants + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "tenants": [ + {"tenantId": tenant.tenant_id, "firstFactors": tenant.first_factors} + for tenant in self.tenants + ], + } + + +async def list_all_tenants_with_login_methods( + _: APIInterface, + __: str, + ___: APIOptions, + user_context: Dict[str, Any], +) -> ListAllTenantsWithLoginMethodsOkResult: + tenants_res = ( + await MultitenancyRecipe.get_instance().recipe_implementation.list_all_tenants( + user_context + ) + ) + final_tenants: List[TenantWithLoginMethods] = [] + + for current_tenant in tenants_res.tenants: + login_methods = ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + current_tenant + ) + ) + + final_tenants.append( + TenantWithLoginMethods( + tenant_id=current_tenant.tenant_id, + first_factors=login_methods, + ) + ) + + return ListAllTenantsWithLoginMethodsOkResult(tenants=final_tenants) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py new file mode 100644 index 000000000..dfc279d42 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions + + +class UpdateTenantCoreConfigOkResult(APIResponse): + status: Literal["OK"] = "OK" + + def __init__(self): + self.status = "OK" + + def to_json(self) -> Dict[str, Literal["OK"]]: + return {"status": self.status} + + +class UpdateTenantCoreConfigUnknownTenantErrorResult(APIResponse): + status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Literal["UNKNOWN_TENANT_ERROR"]]: + return {"status": self.status} + + +class UpdateTenantCoreConfigInvalidConfigErrorResult(APIResponse): + status: Literal["INVALID_CONFIG_ERROR"] = "INVALID_CONFIG_ERROR" + + def __init__(self, message: str): + self.status = "INVALID_CONFIG_ERROR" + self.message = message + + def to_json(self) -> Dict[str, Union[Literal["INVALID_CONFIG_ERROR"], str]]: + return {"status": self.status, "message": self.message} + + +async def update_tenant_core_config( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + UpdateTenantCoreConfigOkResult, + UpdateTenantCoreConfigUnknownTenantErrorResult, + UpdateTenantCoreConfigInvalidConfigErrorResult, +]: + request_body = await options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is required") + name = request_body["name"] + value = request_body["value"] + + mt_recipe = MultitenancyRecipe.get_instance() + + tenant_res = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) + if tenant_res is None: + return UpdateTenantCoreConfigUnknownTenantErrorResult() + + try: + await mt_recipe.recipe_implementation.create_or_update_tenant( + tenant_id=tenant_id, + config=TenantConfigCreateOrUpdate( + core_config={name: value}, + ), + user_context=user_context, + ) + except Exception as err: + err_msg = str(err) + if ( + "SuperTokens core threw an error for a " in err_msg + and "with status code: 400" in err_msg + ): + return UpdateTenantCoreConfigInvalidConfigErrorResult( + message=err_msg.split(" and message: ")[1] + ) + raise err + + return UpdateTenantCoreConfigOkResult() diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py new file mode 100644 index 000000000..2c77eca4c --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +from .utils import ( + get_factor_not_available_message, + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init, +) + + +class UpdateTenantFirstFactorOkResult(APIResponse): + status: Literal["OK"] = "OK" + + def __init__(self): + self.status = "OK" + + def to_json(self) -> Dict[str, Literal["OK"]]: + return {"status": self.status} + + +class UpdateTenantFirstFactorRecipeNotConfiguredOnBackendSdkErrorResult(APIResponse): + status: Literal[ + "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + ] = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + + def __init__(self, message: str): + self.status = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + self.message = message + + def to_json( + self, + ) -> Dict[str, Union[Literal["RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR"], str]]: + return {"status": self.status, "message": self.message} + + +class UpdateTenantFirstFactorUnknownTenantErrorResult(APIResponse): + status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Literal["UNKNOWN_TENANT_ERROR"]]: + return {"status": self.status} + + +async def update_tenant_first_factor( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + UpdateTenantFirstFactorOkResult, + UpdateTenantFirstFactorRecipeNotConfiguredOnBackendSdkErrorResult, + UpdateTenantFirstFactorUnknownTenantErrorResult, +]: + request_body = await options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is required") + factor_id = request_body["factorId"] + enable = request_body["enable"] + + mt_recipe = MultitenancyRecipe.get_instance() + + if enable is True: + if factor_id not in mt_recipe.all_available_first_factors: + return UpdateTenantFirstFactorRecipeNotConfiguredOnBackendSdkErrorResult( + message=get_factor_not_available_message( + factor_id, mt_recipe.all_available_first_factors + ) + ) + + tenant_res = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) + + if tenant_res is None: + return UpdateTenantFirstFactorUnknownTenantErrorResult() + + first_factors = ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_res + ) + ) + + if enable is True: + if factor_id not in first_factors: + first_factors.append(factor_id) + else: + first_factors = [f for f in first_factors if f != factor_id] + + await mt_recipe.recipe_implementation.create_or_update_tenant( + tenant_id=tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + ), + user_context=user_context, + ) + + return UpdateTenantFirstFactorOkResult() diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py new file mode 100644 index 000000000..ad4947754 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py @@ -0,0 +1,147 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +from .utils import ( + get_factor_not_available_message, + get_normalised_required_secondary_factors_based_on_tenant_config_from_core_and_sdk_init, +) + + +class UpdateTenantSecondaryFactorOkResult(APIResponse): + status: Literal["OK"] = "OK" + is_mfa_requirements_for_auth_overridden: bool + + def __init__(self, is_mfa_requirements_for_auth_overridden: bool): + self.status = "OK" + self.is_mfa_requirements_for_auth_overridden = ( + is_mfa_requirements_for_auth_overridden + ) + + def to_json(self) -> Dict[str, Union[Literal["OK"], bool]]: + return { + "status": self.status, + "isMFARequirementsForAuthOverridden": self.is_mfa_requirements_for_auth_overridden, + } + + +class UpdateTenantSecondaryFactorRecipeNotConfiguredOnBackendSdkErrorResult( + APIResponse +): + status: Literal[ + "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + ] = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + + def __init__(self, message: str): + self.status = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + self.message = message + + def to_json( + self, + ) -> Dict[str, Union[Literal["RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR"], str]]: + return {"status": self.status, "message": self.message} + + +class UpdateTenantSecondaryFactorMfaNotInitializedErrorResult(APIResponse): + status: Literal["MFA_NOT_INITIALIZED_ERROR"] = "MFA_NOT_INITIALIZED_ERROR" + + def __init__(self): + self.status = "MFA_NOT_INITIALIZED_ERROR" + + def to_json(self) -> Dict[str, Literal["MFA_NOT_INITIALIZED_ERROR"]]: + return {"status": self.status} + + +class UpdateTenantSecondaryFactorUnknownTenantErrorResult(APIResponse): + status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Literal["UNKNOWN_TENANT_ERROR"]]: + return {"status": self.status} + + +async def update_tenant_secondary_factor( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + UpdateTenantSecondaryFactorOkResult, + UpdateTenantSecondaryFactorRecipeNotConfiguredOnBackendSdkErrorResult, + UpdateTenantSecondaryFactorMfaNotInitializedErrorResult, + UpdateTenantSecondaryFactorUnknownTenantErrorResult, +]: + request_body = await options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is required") + factor_id = request_body["factorId"] + enable = request_body["enable"] + + mt_recipe = MultitenancyRecipe.get_instance() + mfa_instance = MultiFactorAuthRecipe.get_instance() + + if mfa_instance is None: + return UpdateTenantSecondaryFactorMfaNotInitializedErrorResult() + + tenant_res = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) + + if tenant_res is None: + return UpdateTenantSecondaryFactorUnknownTenantErrorResult() + + if enable is True: + all_available_secondary_factors = ( + await mfa_instance.get_all_available_secondary_factor_ids(tenant_res) + ) + + if factor_id not in all_available_secondary_factors: + return ( + UpdateTenantSecondaryFactorRecipeNotConfiguredOnBackendSdkErrorResult( + message=get_factor_not_available_message( + factor_id, all_available_secondary_factors + ) + ) + ) + + secondary_factors = await get_normalised_required_secondary_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_res + ) + + if enable is True: + if factor_id not in secondary_factors: + secondary_factors.append(factor_id) + else: + secondary_factors = [f for f in secondary_factors if f != factor_id] + + await mt_recipe.recipe_implementation.create_or_update_tenant( + tenant_id=tenant_id, + config=TenantConfigCreateOrUpdate( + required_secondary_factors=secondary_factors if secondary_factors else None, + ), + user_context=user_context, + ) + + return UpdateTenantSecondaryFactorOkResult( + is_mfa_requirements_for_auth_overridden=mfa_instance.is_get_mfa_requirements_for_auth_overridden + ) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/utils.py b/supertokens_python/recipe/dashboard/api/multitenancy/utils.py index 48772d25d..9782cea50 100644 --- a/supertokens_python/recipe/dashboard/api/multitenancy/utils.py +++ b/supertokens_python/recipe/dashboard/api/multitenancy/utils.py @@ -1,4 +1,6 @@ from typing import List +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import FactorIds from supertokens_python.recipe.multifactorauth.utils import ( is_factor_configured_for_tenant, ) @@ -38,3 +40,63 @@ def get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( valid_first_factors.append(factor_id) return valid_first_factors + + +def get_factor_not_available_message( + factor_id: str, available_factors: List[str] +) -> str: + recipe_name = factor_id_to_recipe(factor_id) + if recipe_name != "Passwordless": + return f"Please initialise {recipe_name} recipe to be able to use this login method" + + passwordless_factors = [ + FactorIds.LINK_EMAIL, + FactorIds.LINK_PHONE, + FactorIds.OTP_EMAIL, + FactorIds.OTP_PHONE, + ] + passwordless_factors_not_available = [ + f for f in passwordless_factors if f not in available_factors + ] + + if len(passwordless_factors_not_available) == 4: + return ( + "Please initialise Passwordless recipe to be able to use this login method" + ) + + flow_type, contact_method = factor_id.split("-") + return f"Please ensure that Passwordless recipe is initialised with contactMethod: {contact_method.upper()} and flowType: {'USER_INPUT_CODE' if flow_type == 'otp' else 'MAGIC_LINK'}" + + +def factor_id_to_recipe(factor_id: str) -> str: + factor_id_to_recipe_map = { + "emailpassword": "Emailpassword", + "thirdparty": "ThirdParty", + "otp-email": "Passwordless", + "otp-phone": "Passwordless", + "link-email": "Passwordless", + "link-phone": "Passwordless", + "totp": "Totp", + } + + return factor_id_to_recipe_map.get(factor_id, "") + + +async def get_normalised_required_secondary_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_details_from_core: TenantConfig, +) -> List[str]: + mfa_instance = MultiFactorAuthRecipe.get_instance() + + if mfa_instance is None: + return [] + + secondary_factors = await mfa_instance.get_all_available_secondary_factor_ids( + tenant_details_from_core + ) + secondary_factors = [ + factor_id + for factor_id in secondary_factors + if factor_id in (tenant_details_from_core.required_secondary_factors or []) + ] + + return secondary_factors diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 05aed583e..665784caa 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -15,3 +15,7 @@ TENANTS_LIST_API = "/api/tenants/list" TENANT_THIRD_PARTY_CONFIG_API = "/api/thirdparty/config" TENANT_API = "/api/tenant" +LIST_TENANTS_WITH_LOGIN_METHODS = "/api/tenants" +UPDATE_TENANT_CORE_CONFIG_API = "/api/tenant/core-config" +UPDATE_TENANT_FIRST_FACTOR_API = "/api/tenant/first-factor" +UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API = "/api/tenant/required-secondary-factor" diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index fc7761f9f..67bd2c839 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -35,6 +35,18 @@ from supertokens_python.recipe.dashboard.api.multitenancy.get_third_party_config import ( get_third_party_config, ) +from supertokens_python.recipe.dashboard.api.multitenancy.list_all_tenants_with_login_methods import ( + list_all_tenants_with_login_methods, +) +from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_core_config import ( + update_tenant_core_config, +) +from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_first_factor import ( + update_tenant_first_factor, +) +from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_secondary_factor import ( + update_tenant_secondary_factor, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -92,6 +104,10 @@ TENANTS_LIST_API, TENANT_THIRD_PARTY_CONFIG_API, TENANT_API, + LIST_TENANTS_WITH_LOGIN_METHODS, + UPDATE_TENANT_CORE_CONFIG_API, + UPDATE_TENANT_FIRST_FACTOR_API, + UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, ) from .utils import ( InputOverrideConfig, @@ -314,6 +330,40 @@ def get_apis_handled(self) -> List[APIHandled]: TENANT_THIRD_PARTY_CONFIG_API, False, ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(LIST_TENANTS_WITH_LOGIN_METHODS) + ), + "get", + LIST_TENANTS_WITH_LOGIN_METHODS, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(UPDATE_TENANT_CORE_CONFIG_API) + ), + "put", + UPDATE_TENANT_CORE_CONFIG_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(UPDATE_TENANT_FIRST_FACTOR_API) + ), + "put", + UPDATE_TENANT_FIRST_FACTOR_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base( + UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API + ) + ), + "put", + UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, + False, + ), ] async def handle_api_request( @@ -407,6 +457,14 @@ async def handle_api_request( api_function = delete_third_party_config_api if method == "get": api_function = get_third_party_config + elif request_id == LIST_TENANTS_WITH_LOGIN_METHODS: + api_function = list_all_tenants_with_login_methods + elif request_id == UPDATE_TENANT_CORE_CONFIG_API: + api_function = update_tenant_core_config + elif request_id == UPDATE_TENANT_FIRST_FACTOR_API: + api_function = update_tenant_first_factor + elif request_id == UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API: + api_function = update_tenant_secondary_factor if api_function is not None: return await api_key_protector( From 9a4fa380856240e6296300e0eb50c22daaac51d9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 13:44:26 +0530 Subject: [PATCH 037/126] more apis --- .../api/user/create/emailpassword_user.py | 128 +++++++++++++++ .../api/user/create/passwordless_user.py | 150 ++++++++++++++++++ .../recipe/dashboard/constants.py | 2 + supertokens_python/recipe/dashboard/recipe.py | 28 ++++ .../recipe/passwordless/asyncio/__init__.py | 2 +- .../recipe/passwordless/syncio/__init__.py | 2 +- 6 files changed, 310 insertions(+), 2 deletions(-) create mode 100644 supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py create mode 100644 supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py diff --git a/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py b/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py new file mode 100644 index 000000000..35717cdca --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, Union + +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.emailpassword.asyncio import sign_up +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + SignUpOkResult, +) +from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe +from supertokens_python.types import APIResponse, AccountLinkingUser, RecipeUserId + + +class CreateEmailPasswordUserOkResponse(APIResponse): + def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): + self.status = "OK" + self.user = user + self.recipe_user_id = recipe_user_id + + def to_json(self): + return { + "status": self.status, + "user": self.user.to_json(), + "recipeUserId": self.recipe_user_id.get_as_string(), + } + + +class CreateEmailPasswordUserFeatureNotEnabledResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class CreateEmailPasswordUserEmailAlreadyExistsResponse(APIResponse): + def __init__(self): + self.status = "EMAIL_ALREADY_EXISTS_ERROR" + + def to_json(self): + return {"status": self.status} + + +class CreateEmailPasswordUserEmailValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status = "EMAIL_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +class CreateEmailPasswordUserPasswordValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status = "PASSWORD_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +async def create_email_password_user( + _: APIInterface, + tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + CreateEmailPasswordUserOkResponse, + CreateEmailPasswordUserFeatureNotEnabledResponse, + CreateEmailPasswordUserEmailAlreadyExistsResponse, + CreateEmailPasswordUserEmailValidationErrorResponse, + CreateEmailPasswordUserPasswordValidationErrorResponse, +]: + email_password: EmailPasswordRecipe + try: + email_password = EmailPasswordRecipe.get_instance() + except Exception: + return CreateEmailPasswordUserFeatureNotEnabledResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise BadInputError("Request body is missing") + + email = request_body.get("email") + password = request_body.get("password") + + if not isinstance(email, str): + raise BadInputError( + "Required parameter 'email' is missing or has an invalid type" + ) + + if not isinstance(password, str): + raise BadInputError( + "Required parameter 'password' is missing or has an invalid type" + ) + + email_form_field = next( + field + for field in email_password.config.sign_up_feature.form_fields + if field.id == "email" + ) + validate_email_error = await email_form_field.validate(email, tenant_id) + + if validate_email_error is not None: + return CreateEmailPasswordUserEmailValidationErrorResponse(validate_email_error) + + password_form_field = next( + field + for field in email_password.config.sign_up_feature.form_fields + if field.id == "password" + ) + validate_password_error = await password_form_field.validate(password, tenant_id) + + if validate_password_error is not None: + return CreateEmailPasswordUserPasswordValidationErrorResponse( + validate_password_error + ) + + response = await sign_up(tenant_id, email, password) + + if isinstance(response, SignUpOkResult): + return CreateEmailPasswordUserOkResponse(response.user, response.recipe_user_id) + elif isinstance(response, EmailAlreadyExistsError): + return CreateEmailPasswordUserEmailAlreadyExistsResponse() + else: + raise Exception( + "This should never happen: EmailPassword.sign_up threw a session user related error without passing a session" + ) diff --git a/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py b/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py new file mode 100644 index 000000000..51bed90b7 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal + +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.passwordless import ( + ContactEmailOnlyConfig, + ContactEmailOrPhoneConfig, + ContactPhoneOnlyConfig, +) +from supertokens_python.recipe.passwordless.asyncio import signinup +from supertokens_python.recipe.passwordless.recipe import PasswordlessRecipe +from supertokens_python.types import APIResponse, AccountLinkingUser, RecipeUserId +from phonenumbers import parse as parse_phone_number, format_number, PhoneNumberFormat + + +class CreatePasswordlessUserOkResponse(APIResponse): + def __init__( + self, + created_new_recipe_user: bool, + user: AccountLinkingUser, + recipe_user_id: RecipeUserId, + ): + self.status: Literal["OK"] = "OK" + self.created_new_recipe_user = created_new_recipe_user + self.user = user + self.recipe_user_id = recipe_user_id + + def to_json(self): + return { + "status": self.status, + "createdNewRecipeUser": self.created_new_recipe_user, + "user": self.user.to_json(), + "recipeUserId": self.recipe_user_id.get_as_string(), + } + + +class CreatePasswordlessUserFeatureNotEnabledResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class CreatePasswordlessUserEmailValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status: Literal["EMAIL_VALIDATION_ERROR"] = "EMAIL_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +class CreatePasswordlessUserPhoneValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status: Literal["PHONE_VALIDATION_ERROR"] = "PHONE_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +async def create_passwordless_user( + _: APIInterface, + tenant_id: str, + api_options: APIOptions, + __: Dict[str, Any], +) -> Union[ + CreatePasswordlessUserOkResponse, + CreatePasswordlessUserFeatureNotEnabledResponse, + CreatePasswordlessUserEmailValidationErrorResponse, + CreatePasswordlessUserPhoneValidationErrorResponse, +]: + passwordless_recipe: PasswordlessRecipe + try: + passwordless_recipe = PasswordlessRecipe.get_instance() + except Exception: + return CreatePasswordlessUserFeatureNotEnabledResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise BadInputError("Request body is missing") + + email = request_body.get("email") + phone_number = request_body.get("phoneNumber") + + if (email is not None and phone_number is not None) or ( + email is None and phone_number is None + ): + raise BadInputError("Please provide exactly one of email or phoneNumber") + + if email is not None and ( + isinstance(passwordless_recipe.config.contact_config, ContactEmailOnlyConfig) + or isinstance( + passwordless_recipe.config.contact_config, ContactEmailOrPhoneConfig + ) + ): + email = email.strip() + validation_error = ( + await passwordless_recipe.config.contact_config.validate_email_address( + email, tenant_id + ) + ) + if validation_error is not None: + return CreatePasswordlessUserEmailValidationErrorResponse(validation_error) + + if phone_number is not None and ( + isinstance(passwordless_recipe.config.contact_config, ContactPhoneOnlyConfig) + or isinstance( + passwordless_recipe.config.contact_config, ContactEmailOrPhoneConfig + ) + ): + validation_error = ( + await passwordless_recipe.config.contact_config.validate_phone_number( + phone_number, tenant_id + ) + ) + if validation_error is not None: + return CreatePasswordlessUserPhoneValidationErrorResponse(validation_error) + + try: + parsed_phone_number = parse_phone_number(phone_number) + phone_number = format_number(parsed_phone_number, PhoneNumberFormat.E164) + except Exception: + # This can happen if the user has provided their own impl of validate_phone_number + # and the phone number is valid according to their impl, but not according to the phonenumbers lib. + phone_number = phone_number.strip() + + response = await signinup(tenant_id, email=email, phone_number=phone_number) + + return CreatePasswordlessUserOkResponse( + created_new_recipe_user=response.created_new_recipe_user, + user=response.user, + recipe_user_id=response.recipe_user_id, + ) diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 665784caa..92c80470e 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -19,3 +19,5 @@ UPDATE_TENANT_CORE_CONFIG_API = "/api/tenant/core-config" UPDATE_TENANT_FIRST_FACTOR_API = "/api/tenant/first-factor" UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API = "/api/tenant/required-secondary-factor" +CREATE_EMAIL_PASSWORD_USER = "/api/user/emailpassword" +CREATE_PASSWORDLESS_USER = "/api/user/passwordless" diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 67bd2c839..b510a8e6b 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -47,6 +47,12 @@ from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_secondary_factor import ( update_tenant_secondary_factor, ) +from supertokens_python.recipe.dashboard.api.user.create.emailpassword_user import ( + create_email_password_user, +) +from supertokens_python.recipe.dashboard.api.user.create.passwordless_user import ( + create_passwordless_user, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -108,6 +114,8 @@ UPDATE_TENANT_CORE_CONFIG_API, UPDATE_TENANT_FIRST_FACTOR_API, UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, + CREATE_EMAIL_PASSWORD_USER, + CREATE_PASSWORDLESS_USER, ) from .utils import ( InputOverrideConfig, @@ -364,6 +372,22 @@ def get_apis_handled(self) -> List[APIHandled]: UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, False, ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(CREATE_EMAIL_PASSWORD_USER) + ), + "post", + CREATE_EMAIL_PASSWORD_USER, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(CREATE_PASSWORDLESS_USER) + ), + "post", + CREATE_PASSWORDLESS_USER, + False, + ), ] async def handle_api_request( @@ -465,6 +489,10 @@ async def handle_api_request( api_function = update_tenant_first_factor elif request_id == UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API: api_function = update_tenant_secondary_factor + elif request_id == CREATE_EMAIL_PASSWORD_USER: + api_function = create_email_password_user + elif request_id == CREATE_PASSWORDLESS_USER: + api_function = create_passwordless_user if api_function is not None: return await api_key_protector( diff --git a/supertokens_python/recipe/passwordless/asyncio/__init__.py b/supertokens_python/recipe/passwordless/asyncio/__init__.py index d1f21e71f..0c977872b 100644 --- a/supertokens_python/recipe/passwordless/asyncio/__init__.py +++ b/supertokens_python/recipe/passwordless/asyncio/__init__.py @@ -287,7 +287,7 @@ async def signinup( tenant_id: str, email: Union[str, None], phone_number: Union[str, None], - session: Optional[SessionContainer], + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> ConsumeCodeOkResult: if user_context is None: diff --git a/supertokens_python/recipe/passwordless/syncio/__init__.py b/supertokens_python/recipe/passwordless/syncio/__init__.py index 7f583a389..e5a0105a4 100644 --- a/supertokens_python/recipe/passwordless/syncio/__init__.py +++ b/supertokens_python/recipe/passwordless/syncio/__init__.py @@ -272,7 +272,7 @@ def signinup( tenant_id: str, email: Union[str, None], phone_number: Union[str, None], - session: Optional[SessionContainer], + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> ConsumeCodeOkResult: if user_context is None: From 06e5d86981d7face2a2b20df4d930091cc8c949e Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 14:19:52 +0530 Subject: [PATCH 038/126] more apis --- .../api/userdetails/user_email_verify_get.py | 8 +-- .../api/userdetails/user_email_verify_put.py | 10 +-- .../user_email_verify_token_post.py | 16 +++-- .../dashboard/api/userdetails/user_get.py | 44 ++++-------- .../api/userdetails/user_password_put.py | 68 ++++++------------- .../dashboard/api/userdetails/user_put.py | 4 -- .../recipe/dashboard/interfaces.py | 11 +-- supertokens_python/recipe/dashboard/utils.py | 31 --------- 8 files changed, 57 insertions(+), 135 deletions(-) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py index 1b8ff544f..2462b9ffa 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py @@ -19,10 +19,10 @@ async def handle_user_email_verify_get( user_context: Dict[str, Any], ) -> Union[UserEmailVerifyGetAPIResponse, FeatureNotEnabledError]: req = api_options.request - user_id = req.get_query_param("userId") + recipe_user_id = req.get_query_param("recipeUserId") - if user_id is None: - raise_bad_input_exception("Missing required parameter 'userId'") + if recipe_user_id is None: + raise_bad_input_exception("Missing required parameter 'recipeUserId'") try: EmailVerificationRecipe.get_instance_or_throw() @@ -30,6 +30,6 @@ async def handle_user_email_verify_get( return FeatureNotEnabledError() is_verified = await is_email_verified( - RecipeUserId(user_id), user_context=user_context + RecipeUserId(recipe_user_id), user_context=user_context ) return UserEmailVerifyGetAPIResponse(is_verified) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py index e7a1e4700..56206839d 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py @@ -27,12 +27,12 @@ async def handle_user_email_verify_put( user_context: Dict[str, Any], ) -> UserEmailVerifyPutAPIResponse: request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id = request_body.get("userId") + recipe_user_id = request_body.get("recipeUserId") verified = request_body.get("verified") - if user_id is None or not isinstance(user_id, str): + if recipe_user_id is None or not isinstance(recipe_user_id, str): raise_bad_input_exception( - "Required parameter 'userId' is missing or has an invalid type" + "Required parameter 'recipeUserId' is missing or has an invalid type" ) if verified is None or not isinstance(verified, bool): @@ -43,7 +43,7 @@ async def handle_user_email_verify_put( if verified: token_response = await create_email_verification_token( tenant_id=tenant_id, - recipe_user_id=RecipeUserId(user_id), + recipe_user_id=RecipeUserId(recipe_user_id), email=None, user_context=user_context, ) @@ -62,6 +62,6 @@ async def handle_user_email_verify_put( raise Exception("Should not come here") else: - await unverify_email(RecipeUserId(user_id), user_context=user_context) + await unverify_email(RecipeUserId(recipe_user_id), user_context=user_context) return UserEmailVerifyPutAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py index 1392915c1..8707053d6 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Union +from supertokens_python.asyncio import get_user from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailverification.asyncio import ( @@ -26,17 +27,22 @@ async def handle_email_verify_token_post( UserEmailVerifyTokenPostAPIEmailAlreadyVerifiedErrorResponse, ]: request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id = request_body.get("userId") + recipe_user_id = request_body.get("recipeUserId") - if user_id is None or not isinstance(user_id, str): + if recipe_user_id is None or not isinstance(recipe_user_id, str): raise_bad_input_exception( - "Required parameter 'userId' is missing or has an invalid type" + "Required parameter 'recipeUserId' is missing or has an invalid type" ) + user = await get_user(recipe_user_id, user_context) + + if user is None: + raise_bad_input_exception("User not found") + res = await send_email_verification_email( tenant_id=tenant_id, - user_id=user_id, - recipe_user_id=RecipeUserId(user_id), + user_id=user.id, + recipe_user_id=RecipeUserId(recipe_user_id), email=None, user_context=user_context, ) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_get.py index f4559f48d..9d0dcef3e 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_get.py @@ -1,19 +1,19 @@ from typing import Union, Dict, Any +from supertokens_python.asyncio import get_user from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.recipe.dashboard.utils import get_user_for_recipe_id +from supertokens_python.recipe.dashboard.utils import ( + UserWithMetadata, +) from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.usermetadata.asyncio import get_user_metadata -from supertokens_python.types import RecipeUserId from ...interfaces import ( APIInterface, APIOptions, UserGetAPINoUserFoundError, UserGetAPIOkResponse, - UserGetAPIRecipeNotInitialisedError, ) -from ...utils import is_recipe_initialised, is_valid_recipe_id async def handle_user_get( @@ -21,47 +21,31 @@ async def handle_user_get( _tenant_id: str, api_options: APIOptions, _user_context: Dict[str, Any], -) -> Union[ - UserGetAPINoUserFoundError, - UserGetAPIOkResponse, - UserGetAPIRecipeNotInitialisedError, -]: +) -> Union[UserGetAPINoUserFoundError, UserGetAPIOkResponse,]: user_id = api_options.request.get_query_param("userId") - recipe_id = api_options.request.get_query_param("recipeId") if user_id is None: raise_bad_input_exception("Missing required parameter 'userId'") - if recipe_id is None: - raise_bad_input_exception("Missing required parameter 'recipeId'") - - if not is_valid_recipe_id(recipe_id): - raise_bad_input_exception("Invalid recipe id") - - if not is_recipe_initialised(recipe_id): - return UserGetAPIRecipeNotInitialisedError() - - user_response = await get_user_for_recipe_id( - RecipeUserId(user_id), recipe_id, _user_context - ) - if user_response.user is None: + user_response = await get_user(user_id, _user_context) + if user_response is None: return UserGetAPINoUserFoundError() - user = user_response.user + user_with_metadata: UserWithMetadata = UserWithMetadata().from_user(user_response) try: UserMetadataRecipe.get_instance() except Exception: - user.first_name = "FEATURE_NOT_ENABLED" - user.last_name = "FEATURE_NOT_ENABLED" + user_with_metadata.first_name = "FEATURE_NOT_ENABLED" + user_with_metadata.last_name = "FEATURE_NOT_ENABLED" - return UserGetAPIOkResponse(recipe_id, user) + return UserGetAPIOkResponse(user_with_metadata) user_metadata = await get_user_metadata(user_id, user_context=_user_context) first_name = user_metadata.metadata.get("first_name", "") last_name = user_metadata.metadata.get("last_name", "") - user.first_name = first_name - user.last_name = last_name + user_with_metadata.first_name = first_name + user_with_metadata.last_name = last_name - return UserGetAPIOkResponse(recipe_id, user) + return UserGetAPIOkResponse(user_with_metadata) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py index bc72c61bb..1ed63365c 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py @@ -1,17 +1,12 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, Union from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailpassword import EmailPasswordRecipe -from supertokens_python.recipe.emailpassword.constants import FORM_FIELD_PASSWORD_ID from supertokens_python.recipe.emailpassword.interfaces import ( + PasswordPolicyViolationError, UnknownUserIdError, - PasswordResetTokenInvalidError, ) -from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_token, - reset_password_using_token, -) -from supertokens_python.recipe.emailpassword.types import NormalisedFormField +from supertokens_python.types import RecipeUserId from ...interfaces import ( APIInterface, @@ -28,52 +23,33 @@ async def handle_user_password_put( user_context: Dict[str, Any], ) -> Union[UserPasswordPutAPIResponse, UserPasswordPutAPIInvalidPasswordErrorResponse]: request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id = request_body.get("userId") + recipe_user_id = request_body.get("recipeUserId") new_password = request_body.get("newPassword") - if user_id is None or not isinstance(user_id, str): - raise_bad_input_exception("Missing required parameter 'userId'") + if recipe_user_id is None or not isinstance(recipe_user_id, str): + raise_bad_input_exception("Missing required parameter 'recipeUserId'") if new_password is None or not isinstance(new_password, str): raise_bad_input_exception("Missing required parameter 'newPassword'") - async def reset_password( - form_fields: List[NormalisedFormField], - ) -> Union[ - UserPasswordPutAPIResponse, UserPasswordPutAPIInvalidPasswordErrorResponse - ]: - password_form_field = [ - field for field in form_fields if field.id == FORM_FIELD_PASSWORD_ID - ][0] - - password_validation_error = await password_form_field.validate( - new_password, tenant_id - ) - - if password_validation_error is not None: - return UserPasswordPutAPIInvalidPasswordErrorResponse( - password_validation_error - ) - - password_reset_token = await create_reset_password_token( - tenant_id, user_id, "", user_context + email_password_recipe = EmailPasswordRecipe.get_instance() + update_response = ( + await email_password_recipe.recipe_implementation.update_email_or_password( + recipe_user_id=RecipeUserId(recipe_user_id), + email=None, + password=new_password, + apply_password_policy=True, + tenant_id_for_password_policy=tenant_id, + user_context=user_context, ) + ) - if isinstance(password_reset_token, UnknownUserIdError): - # Techincally it can but its an edge case so we assume that it wont - # UNKNOWN_USER_ID_ERROR - raise Exception("Should never come here") - - password_reset_response = await reset_password_using_token( - tenant_id, password_reset_token.token, new_password, user_context + if isinstance(update_response, PasswordPolicyViolationError): + return UserPasswordPutAPIInvalidPasswordErrorResponse( + error=update_response.failure_reason ) - if isinstance(password_reset_response, PasswordResetTokenInvalidError): - # RESET_PASSWORD_INVALID_TOKEN_ERROR - raise Exception("Should not come here") + if isinstance(update_response, UnknownUserIdError): + raise Exception("Should never come here") - return UserPasswordPutAPIResponse() - - return await reset_password( - EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields, - ) + return UserPasswordPutAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py index 537e6e9b3..ea0c7f34b 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py @@ -3,7 +3,6 @@ from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.dashboard.utils import ( get_user_for_recipe_id, - is_valid_recipe_id, ) from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.asyncio import ( @@ -195,9 +194,6 @@ async def handle_user_put( "Required parameter 'recipeId' is missing or has an invalid type" ) - if not is_valid_recipe_id(recipe_id): - raise_bad_input_exception("Invalid recipe id") - if first_name is None and not isinstance(first_name, str): raise_bad_input_exception( "Required parameter 'firstName' is missing or has an invalid type" diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index 9c60692fa..c4c2aa512 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -140,14 +140,12 @@ def to_json(self) -> Dict[str, Any]: class UserGetAPIOkResponse(APIResponse): status: str = "OK" - def __init__(self, recipe_id: str, user: UserWithMetadata): - self.recipe_id = recipe_id + def __init__(self, user: UserWithMetadata): self.user = user def to_json(self) -> Dict[str, Any]: return { "status": self.status, - "recipeId": self.recipe_id, "user": self.user.to_json(), } @@ -159,13 +157,6 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class UserGetAPIRecipeNotInitialisedError(APIResponse): - status: str = "RECIPE_NOT_INITIALISED" - - def to_json(self) -> Dict[str, Any]: - return {"status": self.status} - - class FeatureNotEnabledError(APIResponse): status: str = "FEATURE_NOT_ENABLED_ERROR" diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 310b3c098..7e053dbdf 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -182,10 +182,6 @@ def get_api_if_matched(path: NormalisedURLPath, method: str) -> Optional[str]: return None -def is_valid_recipe_id(recipe_id: str) -> bool: - return recipe_id in ("emailpassword", "thirdparty", "passwordless") - - class GetUserForRecipeIdHelperResult: def __init__( self, user: Optional[AccountLinkingUser] = None, recipe: Optional[str] = None @@ -265,33 +261,6 @@ async def _get_user_for_recipe_id( return GetUserForRecipeIdHelperResult(user=user, recipe=recipe) -def is_recipe_initialised(recipeId: str) -> bool: - isRecipeInitialised: bool = False - - if recipeId == EmailPasswordRecipe.recipe_id: - try: - EmailPasswordRecipe.get_instance() - isRecipeInitialised = True - except Exception: - pass - - elif recipeId == PasswordlessRecipe.recipe_id: - try: - PasswordlessRecipe.get_instance() - isRecipeInitialised = True - except Exception: - pass - - elif recipeId == ThirdPartyRecipe.recipe_id: - try: - ThirdPartyRecipe.get_instance() - isRecipeInitialised = True - except Exception: - pass - - return isRecipeInitialised - - async def validate_api_key( req: BaseRequest, config: DashboardConfig, _user_context: Dict[str, Any] ) -> bool: From 0b08858f9c3e0e8744042e506e31c252eba8dac9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 14:50:24 +0530 Subject: [PATCH 039/126] more apis --- .../dashboard/api/userdetails/user_put.py | 339 +++++++++++------- .../api/userdetails/user_sessions_get.py | 4 +- .../api/userdetails/user_unlink_get.py | 31 ++ .../recipe/dashboard/constants.py | 1 + .../recipe/dashboard/interfaces.py | 1 - supertokens_python/recipe/dashboard/recipe.py | 12 + 6 files changed, 256 insertions(+), 132 deletions(-) create mode 100644 supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py index ea0c7f34b..ff8b47215 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union +from typing_extensions import Literal from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.dashboard.utils import ( @@ -11,24 +12,25 @@ from supertokens_python.recipe.emailpassword.constants import FORM_FIELD_EMAIL_ID from supertokens_python.recipe.emailpassword.interfaces import ( EmailAlreadyExistsError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, +) +from supertokens_python.recipe.passwordless import ( + ContactEmailOnlyConfig, + ContactEmailOrPhoneConfig, + ContactPhoneOnlyConfig, + PasswordlessRecipe, ) -from supertokens_python.recipe.passwordless import PasswordlessRecipe from supertokens_python.recipe.passwordless.asyncio import ( update_user as pless_update_user, ) from supertokens_python.recipe.passwordless.interfaces import ( - UpdateUserEmailAlreadyExistsError as EmailAlreadyExistsErrorResponse, -) -from supertokens_python.recipe.passwordless.interfaces import ( - UpdateUserPhoneNumberAlreadyExistsError as PhoneNumberAlreadyExistsError, -) -from supertokens_python.recipe.passwordless.interfaces import ( - UpdateUserUnknownUserIdError as PlessUpdateUserUnknownUserIdError, + UpdateUserEmailAlreadyExistsError, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ) from supertokens_python.recipe.passwordless.utils import ( - ContactEmailOnlyConfig, - ContactEmailOrPhoneConfig, - ContactPhoneOnlyConfig, default_validate_email, default_validate_phone_number, ) @@ -39,129 +41,211 @@ from ...interfaces import ( APIInterface, APIOptions, - UserPutAPIEmailAlreadyExistsErrorResponse, - UserPutAPIInvalidEmailErrorResponse, - UserPutAPIInvalidPhoneErrorResponse, - UserPutAPIOkResponse, - UserPutPhoneAlreadyExistsAPIResponse, + APIResponse, ) +class OkResponse(APIResponse): + status: Literal["OK"] + + def __init__(self): + self.status = "OK" + + def to_json(self): + return {"status": self.status} + + +class EmailAlreadyExistsErrorResponse(APIResponse): + status: Literal["EMAIL_ALREADY_EXISTS_ERROR"] + + def __init__(self): + self.status = "EMAIL_ALREADY_EXISTS_ERROR" + + def to_json(self): + return {"status": self.status} + + +class InvalidEmailErrorResponse(APIResponse): + status: Literal["INVALID_EMAIL_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "INVALID_EMAIL_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + +class PhoneAlreadyExistsErrorResponse(APIResponse): + status: Literal["PHONE_ALREADY_EXISTS_ERROR"] + + def __init__(self): + self.status = "PHONE_ALREADY_EXISTS_ERROR" + + def to_json(self): + return {"status": self.status} + + +class InvalidPhoneErrorResponse(APIResponse): + status: Literal["INVALID_PHONE_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "INVALID_PHONE_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + +class EmailChangeNotAllowedErrorResponse(APIResponse): + status: Literal["EMAIL_CHANGE_NOT_ALLOWED_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "EMAIL_CHANGE_NOT_ALLOWED_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + +class PhoneNumberChangeNotAllowedErrorResponse(APIResponse): + status: Literal["PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + async def update_email_for_recipe_id( recipe_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: str, tenant_id: str, user_context: Dict[str, Any], ) -> Union[ - UserPutAPIOkResponse, - UserPutAPIInvalidEmailErrorResponse, - UserPutAPIEmailAlreadyExistsErrorResponse, + OkResponse, + InvalidEmailErrorResponse, + EmailAlreadyExistsErrorResponse, + EmailChangeNotAllowedErrorResponse, ]: - validation_error: Optional[str] = None - if recipe_id == "emailpassword": - form_fields = ( - EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields - ) email_form_fields = [ - form_field - for form_field in form_fields - if form_field.id == FORM_FIELD_EMAIL_ID + field + for field in EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields + if field.id == FORM_FIELD_EMAIL_ID ] validation_error = await email_form_fields[0].validate(email, tenant_id) if validation_error is not None: - return UserPutAPIInvalidEmailErrorResponse(validation_error) + return InvalidEmailErrorResponse(validation_error) email_update_response = await ep_update_email_or_password( - RecipeUserId(user_id), email, user_context=user_context + recipe_user_id, email=email, user_context=user_context ) if isinstance(email_update_response, EmailAlreadyExistsError): - return UserPutAPIEmailAlreadyExistsErrorResponse() + return EmailAlreadyExistsErrorResponse() + elif isinstance( + email_update_response, UpdateEmailOrPasswordEmailChangeNotAllowedError + ): + return EmailChangeNotAllowedErrorResponse(email_update_response.reason) - return UserPutAPIOkResponse() + return OkResponse() if recipe_id == "passwordless": - validation_error = None + passwordless_config = PasswordlessRecipe.get_instance().config - passwordless_config = PasswordlessRecipe.get_instance().config.contact_config - - if isinstance(passwordless_config.contact_method, ContactPhoneOnlyConfig): + if isinstance(passwordless_config.contact_config, ContactPhoneOnlyConfig): validation_error = await default_validate_email(email, tenant_id) - - elif isinstance( - passwordless_config, (ContactEmailOnlyConfig, ContactEmailOrPhoneConfig) - ): - validation_error = await passwordless_config.validate_email_address( - email, tenant_id - ) + else: + if isinstance( + passwordless_config.contact_config, + (ContactEmailOnlyConfig, ContactEmailOrPhoneConfig), + ): + validation_error = ( + await passwordless_config.contact_config.validate_email_address( + email, tenant_id + ) + ) + else: + raise Exception("Should never come here") if validation_error is not None: - return UserPutAPIInvalidEmailErrorResponse(validation_error) + return InvalidEmailErrorResponse(validation_error) update_result = await pless_update_user( - RecipeUserId(user_id), email, user_context=user_context + recipe_user_id, email=email, user_context=user_context ) - if isinstance(update_result, PlessUpdateUserUnknownUserIdError): + if isinstance(update_result, UpdateUserUnknownUserIdError): raise Exception("Should never come here") + elif isinstance(update_result, UpdateUserEmailAlreadyExistsError): + return EmailAlreadyExistsErrorResponse() + elif isinstance( + update_result, + ( + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, + ), + ): + return EmailChangeNotAllowedErrorResponse(update_result.reason) - if isinstance(update_result, EmailAlreadyExistsErrorResponse): - return UserPutAPIEmailAlreadyExistsErrorResponse() - - return UserPutAPIOkResponse() + return OkResponse() # If it comes here then the user is a third party user in which case the UI should not have allowed this raise Exception("Should never come here") async def update_phone_for_recipe_id( - recipe_id: str, - user_id: str, + recipe_user_id: RecipeUserId, phone: str, tenant_id: str, user_context: Dict[str, Any], ) -> Union[ - UserPutAPIOkResponse, - UserPutAPIInvalidPhoneErrorResponse, - UserPutPhoneAlreadyExistsAPIResponse, + OkResponse, + InvalidPhoneErrorResponse, + PhoneAlreadyExistsErrorResponse, + PhoneNumberChangeNotAllowedErrorResponse, ]: - validation_error: Optional[str] = None - - if recipe_id == "passwordless": - validation_error = None - - passwordless_config = PasswordlessRecipe.get_instance().config.contact_config - - if isinstance(passwordless_config, ContactEmailOnlyConfig): - validation_error = await default_validate_phone_number(phone, tenant_id) - elif isinstance( - passwordless_config, (ContactPhoneOnlyConfig, ContactEmailOrPhoneConfig) - ): - validation_error = await passwordless_config.validate_phone_number( + passwordless_config = PasswordlessRecipe.get_instance().config + + if isinstance(passwordless_config.contact_config, ContactEmailOnlyConfig): + validation_error = await default_validate_phone_number(phone, tenant_id) + elif isinstance( + passwordless_config.contact_config, + (ContactPhoneOnlyConfig, ContactEmailOrPhoneConfig), + ): + validation_error = ( + await passwordless_config.contact_config.validate_phone_number( phone, tenant_id ) - - if validation_error is not None: - return UserPutAPIInvalidPhoneErrorResponse(validation_error) - - update_result = await pless_update_user( - RecipeUserId(user_id), phone_number=phone, user_context=user_context ) + else: + raise Exception("Invalid contact config") - if isinstance(update_result, PlessUpdateUserUnknownUserIdError): - raise Exception("Should never come here") + if validation_error is not None: + return InvalidPhoneErrorResponse(validation_error) - if isinstance(update_result, PhoneNumberAlreadyExistsError): - return UserPutPhoneAlreadyExistsAPIResponse() + update_result = await pless_update_user( + recipe_user_id, phone_number=phone, user_context=user_context + ) - return UserPutAPIOkResponse() + if isinstance(update_result, UpdateUserUnknownUserIdError): + raise Exception("Should never come here") + elif isinstance(update_result, UpdateUserPhoneNumberAlreadyExistsError): + return PhoneAlreadyExistsErrorResponse() + elif isinstance(update_result, PhoneNumberChangeNotAllowedError): + return PhoneNumberChangeNotAllowedErrorResponse(update_result.reason) - # If it comes here then the user is a third party user in which case the UI should not have allowed this - raise Exception("Should never come here") + return OkResponse() async def handle_user_put( @@ -170,65 +254,63 @@ async def handle_user_put( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - UserPutAPIOkResponse, - UserPutAPIInvalidEmailErrorResponse, - UserPutAPIEmailAlreadyExistsErrorResponse, - UserPutAPIInvalidPhoneErrorResponse, - UserPutPhoneAlreadyExistsAPIResponse, + OkResponse, + InvalidEmailErrorResponse, + EmailAlreadyExistsErrorResponse, + InvalidPhoneErrorResponse, + PhoneAlreadyExistsErrorResponse, + EmailChangeNotAllowedErrorResponse, + PhoneNumberChangeNotAllowedErrorResponse, ]: - request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id: Optional[str] = request_body.get("userId") - recipe_id: Optional[str] = request_body.get("recipeId") - first_name: Optional[str] = request_body.get("firstName") - last_name: Optional[str] = request_body.get("lastName") - email: Optional[str] = request_body.get("email") - phone: Optional[str] = request_body.get("phone") - - if not isinstance(user_id, str): - return raise_bad_input_exception( - "Required parameter 'userId' is missing or has an invalid type" + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + recipe_user_id = request_body.get("recipeUserId") + recipe_id = request_body.get("recipeId") + first_name = request_body.get("firstName") + last_name = request_body.get("lastName") + email = request_body.get("email") + phone = request_body.get("phone") + + if not isinstance(recipe_user_id, str): + raise_bad_input_exception( + "Required parameter 'recipeUserId' is missing or has an invalid type" ) if not isinstance(recipe_id, str): - return raise_bad_input_exception( + raise_bad_input_exception( "Required parameter 'recipeId' is missing or has an invalid type" ) - if first_name is None and not isinstance(first_name, str): + if not isinstance(first_name, str): raise_bad_input_exception( "Required parameter 'firstName' is missing or has an invalid type" ) - if last_name is None and not isinstance(last_name, str): + if not isinstance(last_name, str): raise_bad_input_exception( "Required parameter 'lastName' is missing or has an invalid type" ) - if email is None and not isinstance(email, str): + if not isinstance(email, str): raise_bad_input_exception( "Required parameter 'email' is missing or has an invalid type" ) - if phone is None and not isinstance(phone, str): + if not isinstance(phone, str): raise_bad_input_exception( "Required parameter 'phone' is missing or has an invalid type" ) user_response = await get_user_for_recipe_id( - RecipeUserId(user_id), recipe_id, user_context + RecipeUserId(recipe_user_id), recipe_id, user_context ) - if user_response.user is None: + if user_response.user is None or user_response.recipe is None: raise Exception("Should never come here") - first_name = first_name.strip() - last_name = last_name.strip() - email = email.strip() - phone = phone.strip() - - if first_name != "" or last_name != "": + if first_name.strip() or last_name.strip(): is_recipe_initialized = False - try: UserMetadataRecipe.get_instance() is_recipe_initialized = True @@ -238,36 +320,37 @@ async def handle_user_put( if is_recipe_initialized: metadata_update: Dict[str, Any] = {} - if first_name != "": - metadata_update["first_name"] = first_name + if first_name.strip(): + metadata_update["first_name"] = first_name.strip() - if last_name != "": - metadata_update["last_name"] = last_name + if last_name.strip(): + metadata_update["last_name"] = last_name.strip() - await update_user_metadata(user_id, metadata_update, user_context) + await update_user_metadata( + user_response.user.user.id, metadata_update, user_context + ) - if email != "": + if email.strip(): email_update_response = await update_email_for_recipe_id( - user_response.recipe or "passwordless", - user_id, - email, + user_response.recipe, + RecipeUserId(recipe_user_id), + email.strip(), tenant_id, user_context, ) - if not isinstance(email_update_response, UserPutAPIOkResponse): + if not isinstance(email_update_response, OkResponse): return email_update_response - if phone != "": + if phone.strip(): phone_update_response = await update_phone_for_recipe_id( - user_response.recipe or "passwordless", - user_id, - phone, + RecipeUserId(recipe_user_id), + phone.strip(), tenant_id, user_context, ) - if not isinstance(phone_update_response, UserPutAPIOkResponse): + if not isinstance(phone_update_response, OkResponse): return phone_update_response - return UserPutAPIOkResponse() + return OkResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py index 5dc815efd..bea9de4bd 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py @@ -28,9 +28,7 @@ async def handle_sessions_get( # Passing tenant id as None sets fetch_across_all_tenants to True # which is what we want here. - session_handles = await get_all_session_handles_for_user( - user_id, True, None, user_context - ) + session_handles = await get_all_session_handles_for_user(user_id) sessions: List[Optional[SessionInfo]] = [None for _ in session_handles] async def call_(i: int, session_handle: str): diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py new file mode 100644 index 000000000..2fef96569 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py @@ -0,0 +1,31 @@ +from typing import Any, Dict +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.accountlinking.asyncio import unlink_account +from supertokens_python.recipe.dashboard.utils import RecipeUserId +from ...interfaces import APIInterface, APIOptions +from supertokens_python.types import APIResponse + + +class UserUnlinkGetOkResult(APIResponse): + def __init__(self): + self.status: Literal["OK"] = "OK" + + def to_json(self): + return {"status": self.status} + + +async def handle_user_unlink_get( + _api_interface: APIInterface, + _tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> UserUnlinkGetOkResult: + recipe_user_id = api_options.request.get_query_param("recipeUserId") + + if recipe_user_id is None: + raise_bad_input_exception("Required field recipeUserId is missing") + + await unlink_account(RecipeUserId(recipe_user_id), user_context) + + return UserUnlinkGetOkResult() diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 92c80470e..f7ec1ccde 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -21,3 +21,4 @@ UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API = "/api/tenant/required-secondary-factor" CREATE_EMAIL_PASSWORD_USER = "/api/user/emailpassword" CREATE_PASSWORDLESS_USER = "/api/user/passwordless" +UNLINK_USER = "/api/user/unlink" diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index c4c2aa512..e77925c01 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -183,7 +183,6 @@ def __init__(self, sessions: List[SessionInfo]): "accessTokenPayload": s.access_token_payload, "expiry": s.expiry, "sessionDataInDatabase": s.session_data_in_database, - "tenantId": s.tenant_id, "timeCreated": s.time_created, "userId": s.user_id, "sessionHandle": s.session_handle, diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index b510a8e6b..769286892 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -53,6 +53,9 @@ from supertokens_python.recipe.dashboard.api.user.create.passwordless_user import ( create_passwordless_user, ) +from supertokens_python.recipe.dashboard.api.userdetails.user_unlink_get import ( + handle_user_unlink_get, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -116,6 +119,7 @@ UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, CREATE_EMAIL_PASSWORD_USER, CREATE_PASSWORDLESS_USER, + UNLINK_USER, ) from .utils import ( InputOverrideConfig, @@ -388,6 +392,12 @@ def get_apis_handled(self) -> List[APIHandled]: CREATE_PASSWORDLESS_USER, False, ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(UNLINK_USER)), + "get", + UNLINK_USER, + False, + ), ] async def handle_api_request( @@ -493,6 +503,8 @@ async def handle_api_request( api_function = create_email_password_user elif request_id == CREATE_PASSWORDLESS_USER: api_function = create_passwordless_user + elif request_id == UNLINK_USER: + api_function = handle_user_unlink_get if api_function is not None: return await api_key_protector( From 051c29f93347f5c874f1199226c024c5eaefe36d Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 17:01:01 +0530 Subject: [PATCH 040/126] more apis --- .../permissions/get_permissions_for_role.py | 64 ++++++++++++++++++ .../remove_permissions_from_role.py | 65 +++++++++++++++++++ .../recipe/dashboard/constants.py | 2 + supertokens_python/recipe/dashboard/recipe.py | 36 ++++++++++ 4 files changed, 167 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py create mode 100644 supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py diff --git a/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py b/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py new file mode 100644 index 000000000..670c98f5f --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py @@ -0,0 +1,64 @@ +from typing import Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import get_permissions_for_role +from supertokens_python.recipe.userroles.interfaces import ( + GetPermissionsForRoleOkResult, +) + +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkPermissionsForRoleResponse(APIResponse): + def __init__(self, permissions: list[str]): + self.status = "OK" + self.permissions = permissions + + def to_json(self): + return {"status": self.status, "permissions": self.permissions} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def get_permissions_for_role_api( + _api_interface: APIInterface, + _tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + OkPermissionsForRoleResponse, + FeatureNotEnabledErrorResponse, + UnknownRoleErrorResponse, +]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + role = api_options.request.get_query_param("role") + + if role is None: + raise_bad_input_exception("Required parameter 'role' is missing") + + response = await get_permissions_for_role(role) + + if isinstance(response, GetPermissionsForRoleOkResult): + return OkPermissionsForRoleResponse(response.permissions) + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py b/supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py new file mode 100644 index 000000000..c6a479311 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, List, Union +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import remove_permissions_from_role +from supertokens_python.recipe.userroles.interfaces import ( + RemovePermissionsFromRoleOkResult, +) +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self): + self.status = "OK" + + def to_json(self): + return {"status": self.status} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def remove_permissions_from_role_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Dict[str, Any] +) -> Union[OkResponse, FeatureNotEnabledErrorResponse, UnknownRoleErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + + role = request_body.get("role") + permissions: Union[List[str], None] = request_body.get("permissions") + + if role is None or not isinstance(role, str): + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if permissions is None: + raise_bad_input_exception( + "Required parameter 'permissions' is missing or has an invalid type" + ) + + response = await remove_permissions_from_role(role, permissions) + if isinstance(response, RemovePermissionsFromRoleOkResult): + return OkResponse() + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index f7ec1ccde..71482e958 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -22,3 +22,5 @@ CREATE_EMAIL_PASSWORD_USER = "/api/user/emailpassword" CREATE_PASSWORDLESS_USER = "/api/user/passwordless" UNLINK_USER = "/api/user/unlink" +USERROLES_PERMISSIONS_API = "/api/userroles/role/permissions" +USERROLES_REMOVE_PERMISSIONS_API = "/api/userroles/role/permissions/remove" diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 769286892..be54cc9bd 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -56,6 +56,12 @@ from supertokens_python.recipe.dashboard.api.userdetails.user_unlink_get import ( handle_user_unlink_get, ) +from supertokens_python.recipe.dashboard.api.userroles.permissions.get_permissions_for_role import ( + get_permissions_for_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.permissions.remove_permissions_from_role import ( + remove_permissions_from_role_api, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -120,6 +126,8 @@ CREATE_EMAIL_PASSWORD_USER, CREATE_PASSWORDLESS_USER, UNLINK_USER, + USERROLES_PERMISSIONS_API, + USERROLES_REMOVE_PERMISSIONS_API, ) from .utils import ( InputOverrideConfig, @@ -398,6 +406,30 @@ def get_apis_handled(self) -> List[APIHandled]: UNLINK_USER, False, ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(USERROLES_PERMISSIONS_API) + ), + "get", + USERROLES_PERMISSIONS_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(USERROLES_PERMISSIONS_API) + ), + "put", + USERROLES_PERMISSIONS_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(USERROLES_REMOVE_PERMISSIONS_API) + ), + "put", + USERROLES_REMOVE_PERMISSIONS_API, + False, + ), ] async def handle_api_request( @@ -505,6 +537,10 @@ async def handle_api_request( api_function = create_passwordless_user elif request_id == UNLINK_USER: api_function = handle_user_unlink_get + elif request_id == USERROLES_PERMISSIONS_API: + api_function = get_permissions_for_role_api + elif request_id == USERROLES_REMOVE_PERMISSIONS_API: + api_function = remove_permissions_from_role_api if api_function is not None: return await api_key_protector( From 85b8083efbf070cb1964f6592a4bac87094d31bd Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 17:27:31 +0530 Subject: [PATCH 041/126] more apis --- .../api/userroles/add_role_to_user.py | 71 +++++++++++++++++++ .../api/userroles/get_role_to_user.py | 46 ++++++++++++ .../api/userroles/remove_user_role.py | 65 +++++++++++++++++ .../roles/create_role_or_add_permissions.py | 54 ++++++++++++++ .../api/userroles/roles/delete_role.py | 43 +++++++++++ .../api/userroles/roles/get_all_roles.py | 35 +++++++++ .../recipe/dashboard/constants.py | 2 + supertokens_python/recipe/dashboard/recipe.py | 70 ++++++++++++++++++ 8 files changed, 386 insertions(+) create mode 100644 supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py create mode 100644 supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py create mode 100644 supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py create mode 100644 supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py create mode 100644 supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py create mode 100644 supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py diff --git a/supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py b/supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py new file mode 100644 index 000000000..39f9b9d44 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py @@ -0,0 +1,71 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import add_role_to_user +from supertokens_python.recipe.userroles.interfaces import AddRoleToUserOkResult +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, did_user_already_have_role: bool): + self.status: Literal["OK"] = "OK" + self.did_user_already_have_role = did_user_already_have_role + + def to_json(self): + return { + "status": self.status, + "didUserAlreadyHaveRole": self.did_user_already_have_role, + } + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_ROLE_ERROR"] = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def add_role_to_user_api( + _: APIInterface, tenant_id: str, api_options: APIOptions, __: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse, UnknownRoleErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + + user_id = request_body.get("userId") + role = request_body.get("role") + + if role is None or not isinstance(role, str): + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if user_id is None or not isinstance(user_id, str): + raise_bad_input_exception( + "Required parameter 'userId' is missing or has an invalid type" + ) + + response = await add_role_to_user(tenant_id, user_id, role) + + if isinstance(response, AddRoleToUserOkResult): + return OkResponse( + did_user_already_have_role=response.did_user_already_have_role + ) + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py b/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py new file mode 100644 index 000000000..2cbd79bbc --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py @@ -0,0 +1,46 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import get_roles_for_user +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, roles: list[str]): + self.status: Literal["OK"] = "OK" + self.roles = roles + + def to_json(self): + return { + "status": self.status, + "roles": self.roles, + } + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def get_roles_for_user_api( + _: APIInterface, tenant_id: str, api_options: APIOptions, __: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + user_id = api_options.request.get_query_param("userId") + + if user_id is None: + raise_bad_input_exception( + "Required parameter 'userId' is missing or has an invalid type" + ) + + response = await get_roles_for_user(tenant_id, user_id) + return OkResponse(roles=response.roles) diff --git a/supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py b/supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py new file mode 100644 index 000000000..b110d79de --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py @@ -0,0 +1,65 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import remove_user_role +from supertokens_python.recipe.userroles.interfaces import RemoveUserRoleOkResult +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, did_user_have_role: bool): + self.status: Literal["OK"] = "OK" + self.did_user_have_role = did_user_have_role + + def to_json(self): + return { + "status": self.status, + "didUserHaveRole": self.did_user_have_role, + } + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_ROLE_ERROR"] = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def remove_user_role_api( + _: APIInterface, tenant_id: str, api_options: APIOptions, __: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse, UnknownRoleErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + user_id = api_options.request.get_query_param("userId") + role = api_options.request.get_query_param("role") + + if role is None: + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if user_id is None: + raise_bad_input_exception( + "Required parameter 'userId' is missing or has an invalid type" + ) + + response = await remove_user_role(tenant_id, user_id, role) + + if isinstance(response, RemoveUserRoleOkResult): + return OkResponse(did_user_have_role=response.did_user_have_role) + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py b/supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py new file mode 100644 index 000000000..e5a89b4fb --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py @@ -0,0 +1,54 @@ +from typing import Any, List, Union +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import ( + create_new_role_or_add_permissions, +) +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, created_new_role: bool): + self.status = "OK" + self.created_new_role = created_new_role + + def to_json(self): + return {"status": self.status, "createdNewRole": self.created_new_role} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def create_role_or_add_permissions_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + + role = request_body.get("role") + permissions: Union[List[str], None] = request_body.get("permissions") + + if role is None or not isinstance(role, str): + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if permissions is None: + raise_bad_input_exception( + "Required parameter 'permissions' is missing or has an invalid type" + ) + + response = await create_new_role_or_add_permissions(role, permissions) + return OkResponse(created_new_role=response.created_new_role) diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py b/supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py new file mode 100644 index 000000000..77816d72c --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py @@ -0,0 +1,43 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import delete_role +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, did_role_exist: bool): + self.status: Literal["OK"] = "OK" + self.did_role_exist = did_role_exist + + def to_json(self): + return {"status": self.status, "didRoleExist": self.did_role_exist} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def delete_role_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + role = api_options.request.get_query_param("role") + + if role is None: + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + response = await delete_role(role) + return OkResponse(did_role_exist=response.did_role_exist) diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py b/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py new file mode 100644 index 000000000..35fca6842 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py @@ -0,0 +1,35 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import get_all_roles +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, roles: list[str]): + self.status: Literal["OK"] = "OK" + self.roles = roles + + def to_json(self): + return {"status": self.status, "roles": self.roles} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def get_all_roles_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + response = await get_all_roles() + return OkResponse(roles=response.roles) diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 71482e958..10c127b59 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -24,3 +24,5 @@ UNLINK_USER = "/api/user/unlink" USERROLES_PERMISSIONS_API = "/api/userroles/role/permissions" USERROLES_REMOVE_PERMISSIONS_API = "/api/userroles/role/permissions/remove" +USERROLES_ROLE_API = "/api/userroles/role" +USERROLES_USER_API = "/api/userroles/user/roles" diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index be54cc9bd..b547593b8 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -56,12 +56,30 @@ from supertokens_python.recipe.dashboard.api.userdetails.user_unlink_get import ( handle_user_unlink_get, ) +from supertokens_python.recipe.dashboard.api.userroles.add_role_to_user import ( + add_role_to_user_api, +) +from supertokens_python.recipe.dashboard.api.userroles.get_role_to_user import ( + get_roles_for_user_api, +) from supertokens_python.recipe.dashboard.api.userroles.permissions.get_permissions_for_role import ( get_permissions_for_role_api, ) from supertokens_python.recipe.dashboard.api.userroles.permissions.remove_permissions_from_role import ( remove_permissions_from_role_api, ) +from supertokens_python.recipe.dashboard.api.userroles.remove_user_role import ( + remove_user_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.roles.create_role_or_add_permissions import ( + create_role_or_add_permissions_api, +) +from supertokens_python.recipe.dashboard.api.userroles.roles.delete_role import ( + delete_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.roles.get_all_roles import ( + get_all_roles_api, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -128,6 +146,8 @@ UNLINK_USER, USERROLES_PERMISSIONS_API, USERROLES_REMOVE_PERMISSIONS_API, + USERROLES_ROLE_API, + USERROLES_USER_API, ) from .utils import ( InputOverrideConfig, @@ -430,6 +450,42 @@ def get_apis_handled(self) -> List[APIHandled]: USERROLES_REMOVE_PERMISSIONS_API, False, ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_ROLE_API)), + "put", + USERROLES_ROLE_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_ROLE_API)), + "delete", + USERROLES_ROLE_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_ROLE_API)), + "get", + USERROLES_ROLE_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_USER_API)), + "get", + USERROLES_USER_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_USER_API)), + "put", + USERROLES_USER_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_USER_API)), + "delete", + USERROLES_USER_API, + False, + ), ] async def handle_api_request( @@ -541,6 +597,20 @@ async def handle_api_request( api_function = get_permissions_for_role_api elif request_id == USERROLES_REMOVE_PERMISSIONS_API: api_function = remove_permissions_from_role_api + elif request_id == USERROLES_ROLE_API: + if method == "put": + api_function = create_role_or_add_permissions_api + if method == "delete": + api_function = delete_role_api + if method == "get": + api_function = get_all_roles_api + elif request_id == USERROLES_USER_API: + if method == "get": + api_function = get_roles_for_user_api + if method == "put": + api_function = add_role_to_user_api + if method == "delete": + api_function = remove_user_role_api if api_function is not None: return await api_key_protector( From edd294fb799c6cfaea2211ebc5286b88a20a4dd9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 17:41:36 +0530 Subject: [PATCH 042/126] more apis --- .../recipe/dashboard/api/__init__.py | 2 - .../recipe/dashboard/api/list_tenants.py | 48 ------------------- .../recipe/dashboard/api/users_get.py | 21 ++++---- .../recipe/dashboard/constants.py | 1 - .../recipe/dashboard/interfaces.py | 4 +- supertokens_python/recipe/dashboard/recipe.py | 10 ---- supertokens_python/recipe/dashboard/utils.py | 4 +- 7 files changed, 13 insertions(+), 77 deletions(-) delete mode 100644 supertokens_python/recipe/dashboard/api/list_tenants.py diff --git a/supertokens_python/recipe/dashboard/api/__init__.py b/supertokens_python/recipe/dashboard/api/__init__.py index 067fa84ec..3ae3a869a 100644 --- a/supertokens_python/recipe/dashboard/api/__init__.py +++ b/supertokens_python/recipe/dashboard/api/__init__.py @@ -31,7 +31,6 @@ from .users_count_get import handle_users_count_get_api from .users_get import handle_users_get_api from .validate_key import handle_validate_key_api -from .list_tenants import handle_list_tenants_api __all__ = [ "handle_dashboard_api", @@ -54,5 +53,4 @@ "handle_emailpassword_signout_api", "handle_get_tags", "handle_analytics_post", - "handle_list_tenants_api", ] diff --git a/supertokens_python/recipe/dashboard/api/list_tenants.py b/supertokens_python/recipe/dashboard/api/list_tenants.py deleted file mode 100644 index 810ee4c6b..000000000 --- a/supertokens_python/recipe/dashboard/api/list_tenants.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Dict, List - -from supertokens_python.recipe.dashboard.interfaces import DashboardListTenantItem - -if TYPE_CHECKING: - from supertokens_python.recipe.dashboard.interfaces import ( - APIOptions, - APIInterface, - ) - from supertokens_python.types import APIResponse - -from supertokens_python.recipe.multitenancy.asyncio import list_all_tenants -from supertokens_python.recipe.dashboard.interfaces import ( - DashboardListTenantsGetResponse, -) - - -async def handle_list_tenants_api( - _api_implementation: APIInterface, - _tenant_id: str, - _api_options: APIOptions, - user_context: Dict[str, Any], -) -> APIResponse: - tenants = await list_all_tenants(user_context) - - final_tenants: List[DashboardListTenantItem] = [] - - for current_tenant in tenants.tenants: - dashboard_tenant = DashboardListTenantItem(current_tenant) - final_tenants.append(dashboard_tenant) - - return DashboardListTenantsGetResponse(final_tenants) diff --git a/supertokens_python/recipe/dashboard/api/users_get.py b/supertokens_python/recipe/dashboard/api/users_get.py index 6da1a961a..817f51e8c 100644 --- a/supertokens_python/recipe/dashboard/api/users_get.py +++ b/supertokens_python/recipe/dashboard/api/users_get.py @@ -15,7 +15,6 @@ import asyncio from typing import TYPE_CHECKING, Any, Awaitable, List, Dict -from typing_extensions import Literal from ...usermetadata import UserMetadataRecipe from ...usermetadata.asyncio import get_user_metadata @@ -27,7 +26,6 @@ APIOptions, APIInterface, ) - from supertokens_python.types import APIResponse from supertokens_python.exceptions import GeneralError, raise_bad_input_exception from supertokens_python.asyncio import get_users_newest_first, get_users_oldest_first @@ -38,16 +36,14 @@ async def handle_users_get_api( tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], -) -> APIResponse: +) -> DashboardUsersGetResponse: _ = api_implementation limit = api_options.request.get_query_param("limit") if limit is None: raise_bad_input_exception("Missing required parameter 'limit'") - time_joined_order: Literal["ASC", "DESC"] = api_options.request.get_query_param( # type: ignore - "timeJoinedOrder", "DESC" - ) + time_joined_order = api_options.request.get_query_param("timeJoinedOrder", "DESC") if time_joined_order not in ["ASC", "DESC"]: raise_bad_input_exception("Invalid value received for 'timeJoinedOrder'") @@ -69,8 +65,11 @@ async def handle_users_get_api( try: UserMetadataRecipe.get_instance() except GeneralError: + users_with_metadata: List[UserWithMetadata] = [ + UserWithMetadata().from_user(user) for user in users_response.users + ] return DashboardUsersGetResponse( - users_response.users, users_response.next_pagination_token + users_with_metadata, users_response.next_pagination_token ) users_with_metadata: List[UserWithMetadata] = [ @@ -114,14 +113,14 @@ async def get_user_metadata_and_update_user(user_idx: int) -> None: ) -def get_search_params_from_url(url: str) -> Dict[str, str]: +def get_search_params_from_url(path: str) -> Dict[str, str]: from urllib.parse import urlparse, parse_qs - parsed_url = urlparse(url) - query_params = parse_qs(parsed_url.query) + url_object = urlparse("https://example.com" + path) + params = parse_qs(url_object.query) search_query = { key: value[0] - for key, value in query_params.items() + for key, value in params.items() if key not in ["limit", "timeJoinedOrder", "paginationToken"] } return search_query diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 10c127b59..7f7b9287d 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -12,7 +12,6 @@ SIGN_OUT_API = "/api/signout" SEARCH_TAGS_API = "/api/search/tags" DASHBOARD_ANALYTICS_API = "/api/analytics" -TENANTS_LIST_API = "/api/tenants/list" TENANT_THIRD_PARTY_CONFIG_API = "/api/thirdparty/config" TENANT_API = "/api/tenant" LIST_TENANTS_WITH_LOGIN_METHODS = "/api/tenants" diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index e77925c01..9a3435488 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -18,8 +18,6 @@ from typing_extensions import Literal from supertokens_python.recipe.multitenancy.interfaces import TenantConfig -from supertokens_python.types import AccountLinkingUser - from ...types import APIResponse if TYPE_CHECKING: @@ -90,7 +88,7 @@ class DashboardUsersGetResponse(APIResponse): def __init__( self, - users: Union[List[AccountLinkingUser], List[UserWithMetadata]], + users: List[UserWithMetadata], next_pagination_token: Optional[str], ): self.users = users diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index b547593b8..93912f5ee 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -103,7 +103,6 @@ handle_users_count_get_api, handle_users_get_api, handle_validate_key_api, - handle_list_tenants_api, ) from .api.implementation import APIImplementation from .exceptions import SuperTokensDashboardError @@ -134,7 +133,6 @@ USERS_COUNT_API, USERS_LIST_GET_API, VALIDATE_KEY_API, - TENANTS_LIST_API, TENANT_THIRD_PARTY_CONFIG_API, TENANT_API, LIST_TENANTS_WITH_LOGIN_METHODS, @@ -322,12 +320,6 @@ def get_apis_handled(self) -> List[APIHandled]: DASHBOARD_ANALYTICS_API, False, ), - APIHandled( - NormalisedURLPath(get_api_path_with_dashboard_base(TENANTS_LIST_API)), - "get", - TENANTS_LIST_API, - False, - ), APIHandled( NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), "post", @@ -563,8 +555,6 @@ async def handle_api_request( elif request_id == DASHBOARD_ANALYTICS_API: if method == "post": api_function = handle_analytics_post - elif request_id == TENANTS_LIST_API: - api_function = handle_list_tenants_api elif request_id == TENANT_API: if method == "post": api_function = create_tenant diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 7e053dbdf..b66d3d6a5 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -58,8 +58,8 @@ def from_user( first_name: Optional[str] = None, last_name: Optional[str] = None, ): - self.first_name = first_name - self.last_name = last_name + self.first_name = first_name or "" + self.last_name = last_name or "" self.user = user return self From 2a7599f2fff7390773c9796ef0b975320b365722 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 17 Sep 2024 17:45:15 +0530 Subject: [PATCH 043/126] renames AccountLinkingUser to User --- supertokens_python/asyncio/__init__.py | 6 +-- supertokens_python/auth_utils.py | 42 ++++++++--------- .../recipe/accountlinking/__init__.py | 2 +- .../recipe/accountlinking/asyncio/__init__.py | 6 +-- .../recipe/accountlinking/interfaces.py | 18 ++++---- .../recipe/accountlinking/recipe.py | 38 +++++++--------- .../accountlinking/recipe_implementation.py | 16 +++---- .../recipe/accountlinking/syncio/__init__.py | 6 +-- .../recipe/accountlinking/types.py | 6 +-- .../recipe/accountlinking/utils.py | 12 +++-- .../api/user/create/emailpassword_user.py | 4 +- .../api/user/create/passwordless_user.py | 4 +- supertokens_python/recipe/dashboard/utils.py | 10 ++--- .../recipe/emailpassword/interfaces.py | 12 ++--- .../recipe/emailpassword/recipe.py | 6 +-- .../emailpassword/recipe_implementation.py | 6 +-- .../recipe/emailverification/recipe.py | 4 +- .../recipe_implementation.py | 4 +- .../recipe/multifactorauth/interfaces.py | 6 +-- .../recipe/multifactorauth/recipe.py | 6 +-- .../multifactorauth/recipe_implementation.py | 6 +-- .../recipe/multifactorauth/types.py | 8 ++-- .../recipe/passwordless/api/implementation.py | 10 ++--- .../recipe/passwordless/interfaces.py | 6 +-- .../recipe/passwordless/recipe.py | 12 +++-- .../passwordless/recipe_implementation.py | 4 +- .../recipe/thirdparty/interfaces.py | 8 ++-- .../thirdparty/recipe_implementation.py | 4 +- supertokens_python/recipe/thirdparty/types.py | 4 +- supertokens_python/recipe/totp/recipe.py | 6 +-- supertokens_python/syncio/__init__.py | 6 +-- supertokens_python/types.py | 45 ++----------------- supertokens_python/utils.py | 4 +- tests/auth-react/flask-server/app.py | 4 +- 34 files changed, 140 insertions(+), 201 deletions(-) diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index f92abece2..7fa32322e 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -26,7 +26,7 @@ ) from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult -from supertokens_python.types import AccountInfo, AccountLinkingUser +from supertokens_python.types import AccountInfo, User async def get_users_oldest_first( @@ -97,7 +97,7 @@ async def delete_user( async def get_user( user_id: str, user_context: Optional[Dict[str, Any]] = None -) -> Optional[AccountLinkingUser]: +) -> Optional[User]: if user_context is None: user_context = {} return await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( @@ -162,7 +162,7 @@ async def list_users_by_account_info( account_info: AccountInfo, do_union_of_account_info: bool = False, user_context: Optional[Dict[str, Any]] = None, -) -> List[AccountLinkingUser]: +) -> List[User]: if user_context is None: user_context = {} return await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 92f8de1f8..9b22c6342 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -26,7 +26,7 @@ from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.types import ( AccountInfo, - AccountLinkingUser, + User, LoginMethod, ) from supertokens_python.recipe.accountlinking.interfaces import ( @@ -85,7 +85,7 @@ class SignInNotAllowedResponse: async def pre_auth_checks( authenticating_account_info: AccountInfoWithRecipeId, - authenticating_user: Union[AccountLinkingUser, None], + authenticating_user: Union[User, None], tenant_id: str, factor_ids: List[str], is_sign_up: bool, @@ -196,11 +196,9 @@ async def pre_auth_checks( class PostAuthChecksOkResponse: status: Literal["OK"] session: SessionContainer - user: AccountLinkingUser + user: User - def __init__( - self, status: Literal["OK"], session: SessionContainer, user: AccountLinkingUser - ): + def __init__(self, status: Literal["OK"], session: SessionContainer, user: User): self.status = status self.session = session self.user = user @@ -211,7 +209,7 @@ class PostAuthChecksSignInNotAllowedResponse: async def post_auth_checks( - authenticated_user: AccountLinkingUser, + authenticated_user: User, recipe_user_id: RecipeUserId, is_sign_up: bool, factor_id: str, @@ -280,9 +278,7 @@ async def post_auth_checks( class AuthenticatingUserInfo: - def __init__( - self, user: AccountLinkingUser, login_method: Union[LoginMethod, None] - ): + def __init__(self, user: User, login_method: Union[LoginMethod, None]): self.user = user self.login_method = login_method @@ -457,9 +453,9 @@ class OkSecondFactorLinkedResponse: status: Literal["OK"] is_first_factor: Literal[False] input_user_already_linked_to_session_user: Literal[True] - session_user: AccountLinkingUser + session_user: User - def __init__(self, session_user: AccountLinkingUser): + def __init__(self, session_user: User): self.session_user = session_user @@ -467,12 +463,12 @@ class OkSecondFactorNotLinkedResponse: status: Literal["OK"] is_first_factor: Literal[False] input_user_already_linked_to_session_user: Literal[False] - session_user: AccountLinkingUser + session_user: User linking_to_session_user_requires_verification: bool def __init__( self, - session_user: AccountLinkingUser, + session_user: User, linking_to_session_user_requires_verification: bool, ): self.session_user = session_user @@ -484,7 +480,7 @@ def __init__( async def check_auth_type_and_linking_status( session: Union[SessionContainer, None], account_info: AccountInfoWithRecipeId, - input_user: Union[AccountLinkingUser, None], + input_user: Union[User, None], skip_session_user_update_in_core: bool, user_context: Dict[str, Any], ) -> Union[ @@ -494,7 +490,7 @@ async def check_auth_type_and_linking_status( LinkingToSessionUserFailedError, ]: log_debug_message("check_auth_type_and_linking_status called") - session_user: Union[AccountLinkingUser, None] = None + session_user: Union[User, None] = None if session is None: log_debug_message( "check_auth_type_and_linking_status returning first factor because there is no session" @@ -559,16 +555,16 @@ async def check_auth_type_and_linking_status( class OkResponse2: status: Literal["OK"] - user: AccountLinkingUser + user: User - def __init__(self, user: AccountLinkingUser): + def __init__(self, user: User): self.status = "OK" - self.user: AccountLinkingUser = user + self.user: User = user async def link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( tenant_id: str, - input_user: AccountLinkingUser, + input_user: User, recipe_user_id: RecipeUserId, session: Union[SessionContainer, None], user_context: Dict[str, Any], @@ -775,8 +771,8 @@ async def try_and_make_session_user_into_a_primary_user( async def try_linking_by_session( linking_to_session_user_requires_verification: bool, auth_login_method: LoginMethod, - authenticated_user: AccountLinkingUser, - session_user: AccountLinkingUser, + authenticated_user: User, + session_user: User, user_context: Dict[str, Any], ) -> Union[OkResponse2, LinkingToSessionUserFailedError,]: log_debug_message("tryLinkingBySession called") @@ -872,7 +868,7 @@ async def filter_out_invalid_first_factors_or_throw_if_all_are_invalid( async def filter_out_invalid_second_factors_or_throw_if_all_are_invalid( factor_ids: List[str], input_user_already_linked_to_session_user: bool, - session_user: AccountLinkingUser, + session_user: User, session: SessionContainer, user_context: Dict[str, Any], ) -> List[str]: diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py index 71ffa8119..866437269 100644 --- a/supertokens_python/recipe/accountlinking/__init__.py +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -21,7 +21,7 @@ from .recipe import AccountLinkingRecipe InputOverrideConfig = utils.InputOverrideConfig -AccountLinkingUser = types.AccountLinkingUser +AccountLinkingUser = types.User RecipeLevelUser = types.RecipeLevelUser AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId SessionContainer = types.SessionContainer diff --git a/supertokens_python/recipe/accountlinking/asyncio/__init__.py b/supertokens_python/recipe/accountlinking/asyncio/__init__.py index b104f37d5..af8b2054b 100644 --- a/supertokens_python/recipe/accountlinking/asyncio/__init__.py +++ b/supertokens_python/recipe/accountlinking/asyncio/__init__.py @@ -13,7 +13,7 @@ # under the License. from typing import Any, Dict, Optional -from ..types import AccountInfoWithRecipeId, AccountLinkingUser, RecipeUserId +from ..types import AccountInfoWithRecipeId, User, RecipeUserId from ..recipe import AccountLinkingRecipe from supertokens_python.recipe.session import SessionContainer from supertokens_python.asyncio import get_user @@ -24,7 +24,7 @@ async def create_primary_user_id_or_link_accounts( recipe_user_id: RecipeUserId, session: Optional[SessionContainer] = None, user_context: Optional[Dict[str, Any]] = None, -) -> AccountLinkingUser: +) -> User: if user_context is None: user_context = {} user = await get_user(recipe_user_id.get_as_string(), user_context) @@ -46,7 +46,7 @@ async def get_primary_user_that_can_be_linked_to_recipe_user_id( tenant_id: str, recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, -) -> Optional[AccountLinkingUser]: +) -> Optional[User]: if user_context is None: user_context = {} user = await get_user(recipe_user_id.get_as_string(), user_context) diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index c5ccd1f4b..8b28e4915 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from supertokens_python.types import ( - AccountLinkingUser, + User, RecipeUserId, AccountInfo, ) @@ -96,7 +96,7 @@ async def unlink_account( @abstractmethod async def get_user( self, user_id: str, user_context: Dict[str, Any] - ) -> Optional[AccountLinkingUser]: + ) -> Optional[User]: pass @abstractmethod @@ -106,7 +106,7 @@ async def list_users_by_account_info( account_info: AccountInfo, do_union_of_account_info: bool, user_context: Dict[str, Any], - ) -> List[AccountLinkingUser]: + ) -> List[User]: pass @abstractmethod @@ -120,9 +120,7 @@ async def delete_user( class GetUsersResult: - def __init__( - self, users: List[AccountLinkingUser], next_pagination_token: Optional[str] - ): + def __init__(self, users: List[User], next_pagination_token: Optional[str]): self.users = users self.next_pagination_token = next_pagination_token @@ -152,7 +150,7 @@ def __init__(self, primary_user_id: str, description: str): class CreatePrimaryUserOkResult: - def __init__(self, user: AccountLinkingUser, was_already_a_primary_user: bool): + def __init__(self, user: User, was_already_a_primary_user: bool): self.status: Literal["OK"] = "OK" self.user = user self.was_already_a_primary_user = was_already_a_primary_user @@ -213,7 +211,7 @@ def __init__(self, description: Optional[str] = None): class LinkAccountsOkResult: - def __init__(self, accounts_already_linked: bool, user: AccountLinkingUser): + def __init__(self, accounts_already_linked: bool, user: User): self.status: Literal["OK"] = "OK" self.accounts_already_linked = accounts_already_linked self.user = user @@ -223,7 +221,7 @@ class LinkAccountsRecipeUserIdAlreadyLinkedError: def __init__( self, primary_user_id: Optional[str] = None, - user: Optional[AccountLinkingUser] = None, + user: Optional[User] = None, description: Optional[str] = None, ): self.status: Literal[ @@ -238,7 +236,7 @@ class LinkAccountsAccountInfoAlreadyAssociatedError: def __init__( self, primary_user_id: Optional[str] = None, - user: Optional[AccountLinkingUser] = None, + user: Optional[User] = None, description: Optional[str] = None, ): self.status: Literal[ diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 3f4a06292..8c9154665 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo - from supertokens_python.types import AccountLinkingUser, LoginMethod, RecipeUserId + from supertokens_python.types import User, LoginMethod, RecipeUserId from supertokens_python.recipe.session import SessionContainer from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.emailverification.recipe import ( @@ -64,9 +64,7 @@ def __init__( class TryLinkingByAccountInfoOrCreatePrimaryUserResult: - def __init__( - self, status: Literal["OK", "NO_LINK"], user: Optional[AccountLinkingUser] - ): + def __init__(self, status: Literal["OK", "NO_LINK"], user: Optional[User]): self.status: Literal["OK", "NO_LINK"] = status self.user = user @@ -80,15 +78,13 @@ def __init__( recipe_id: str, app_info: AppInfo, on_account_linked: Optional[ - Callable[ - [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None] - ] + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] ] = None, should_do_automatic_account_linking: Optional[ Callable[ [ AccountInfoWithRecipeIdAndUserId, - Optional[AccountLinkingUser], + Optional[User], Optional[SessionContainer], str, Dict[str, Any], @@ -152,15 +148,13 @@ def get_all_cors_headers(self) -> List[str]: @staticmethod def init( on_account_linked: Optional[ - Callable[ - [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None] - ] + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] ] = None, should_do_automatic_account_linking: Optional[ Callable[ [ AccountInfoWithRecipeIdAndUserId, - Optional[AccountLinkingUser], + Optional[User], Optional[SessionContainer], str, Dict[str, Any], @@ -206,9 +200,9 @@ def reset(): async def get_primary_user_that_can_be_linked_to_recipe_user_id( self, tenant_id: str, - user: AccountLinkingUser, + user: User, user_context: Dict[str, Any], - ) -> Optional[AccountLinkingUser]: + ) -> Optional[User]: # First we check if this user itself is a primary user or not. If it is, we return that. if user.is_primary_user: return user @@ -262,9 +256,9 @@ async def get_primary_user_that_can_be_linked_to_recipe_user_id( async def get_oldest_user_that_can_be_linked_to_recipe_user( self, tenant_id: str, - user: AccountLinkingUser, + user: User, user_context: Dict[str, Any], - ) -> Optional[AccountLinkingUser]: + ) -> Optional[User]: # First we check if this user itself is a primary user or not. If it is, we return that since it cannot be linked to anything else if user.is_primary_user: return user @@ -287,7 +281,7 @@ async def get_oldest_user_that_can_be_linked_to_recipe_user( async def is_sign_in_allowed( self, - user: AccountLinkingUser, + user: User, account_info: Union[AccountInfoWithRecipeId, LoginMethod], tenant_id: str, session: Optional[SessionContainer], @@ -343,7 +337,7 @@ async def is_sign_in_up_allowed_helper( session: Optional[SessionContainer], tenant_id: str, is_sign_in: bool, - user: Optional[AccountLinkingUser], + user: Optional[User], user_context: Dict[str, Any], ) -> bool: ProcessState.get_instance().add_state( @@ -529,7 +523,7 @@ async def is_sign_in_up_allowed_helper( async def is_email_change_allowed( self, - user: AccountLinkingUser, + user: User, new_email: str, is_verified: bool, session: Optional[SessionContainer], @@ -694,7 +688,7 @@ async def is_email_change_allowed( # pylint:disable=no-self-use async def verify_email_for_recipe_user_if_linked_accounts_are_verified( self, - user: AccountLinkingUser, + user: User, recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> None: @@ -739,7 +733,7 @@ async def verify_email_for_recipe_user_if_linked_accounts_are_verified( async def should_become_primary_user( self, - user: AccountLinkingUser, + user: User, tenant_id: str, session: Optional[SessionContainer], user_context: Dict[str, Any], @@ -776,7 +770,7 @@ async def should_become_primary_user( async def try_linking_by_account_info_or_create_primary_user( self, - input_user: AccountLinkingUser, + input_user: User, session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 8db051480..91c1db61f 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -35,7 +35,7 @@ LinkAccountsAccountInfoAlreadyAssociatedError, LinkAccountsInputUserNotPrimaryError, UnlinkAccountOkResult, - AccountLinkingUser, + User, RecipeUserId, AccountInfo, ) @@ -87,7 +87,7 @@ async def get_users( ) return GetUsersResult( - users=[AccountLinkingUser.from_json(u) for u in response["users"]], + users=[User.from_json(u) for u in response["users"]], next_pagination_token=response.get("nextPaginationToken"), ) @@ -142,7 +142,7 @@ async def create_primary_user( if response["status"] == "OK": return CreatePrimaryUserOkResult( - AccountLinkingUser.from_json(response["user"]), + User.from_json(response["user"]), response["wasAlreadyAPrimaryUser"], ) elif ( @@ -227,7 +227,7 @@ async def link_accounts( "OK", "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", ]: - response["user"] = AccountLinkingUser.from_json(response["user"]) + response["user"] = User.from_json(response["user"]) if response["status"] == "OK": user = response["user"] @@ -310,7 +310,7 @@ async def unlink_account( async def get_user( self, user_id: str, user_context: Dict[str, Any] - ) -> Optional[AccountLinkingUser]: + ) -> Optional[User]: response = await self.querier.send_get_request( NormalisedURLPath("/user/id"), { @@ -319,7 +319,7 @@ async def get_user( user_context, ) if response["status"] == "OK": - return AccountLinkingUser.from_json(response["user"]) + return User.from_json(response["user"]) return None async def list_users_by_account_info( @@ -328,7 +328,7 @@ async def list_users_by_account_info( account_info: AccountInfo, do_union_of_account_info: bool, user_context: Dict[str, Any], - ) -> List[AccountLinkingUser]: + ) -> List[User]: params = { "email": account_info.email, "phoneNumber": account_info.phone_number, @@ -345,7 +345,7 @@ async def list_users_by_account_info( user_context, ) - return [AccountLinkingUser.from_json(u) for u in response["users"]] + return [User.from_json(u) for u in response["users"]] async def delete_user( self, diff --git a/supertokens_python/recipe/accountlinking/syncio/__init__.py b/supertokens_python/recipe/accountlinking/syncio/__init__.py index 832b5dfad..9153a612a 100644 --- a/supertokens_python/recipe/accountlinking/syncio/__init__.py +++ b/supertokens_python/recipe/accountlinking/syncio/__init__.py @@ -15,7 +15,7 @@ from supertokens_python.async_to_sync_wrapper import sync -from ..types import AccountInfoWithRecipeId, AccountLinkingUser, RecipeUserId +from ..types import AccountInfoWithRecipeId, User, RecipeUserId from supertokens_python.recipe.session import SessionContainer @@ -24,7 +24,7 @@ def create_primary_user_id_or_link_accounts( recipe_user_id: RecipeUserId, session: Optional[SessionContainer] = None, user_context: Optional[Dict[str, Any]] = None, -) -> AccountLinkingUser: +) -> User: from ..asyncio import ( create_primary_user_id_or_link_accounts as async_create_primary_user_id_or_link_accounts, ) @@ -40,7 +40,7 @@ def get_primary_user_that_can_be_linked_to_recipe_user_id( tenant_id: str, recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, -) -> Optional[AccountLinkingUser]: +) -> Optional[User]: from ..asyncio import ( get_primary_user_that_can_be_linked_to_recipe_user_id as async_get_primary_user_that_can_be_linked_to_recipe_user_id, ) diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 0b380bac7..748c70059 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -24,7 +24,7 @@ from supertokens_python.types import ( RecipeUserId, ThirdPartyInfo, - AccountLinkingUser, + User, LoginMethod, ) from supertokens_python.recipe.session import SessionContainer @@ -134,12 +134,12 @@ class AccountLinkingConfig: def __init__( self, on_account_linked: Callable[ - [AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None] + [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] ], should_do_automatic_account_linking: Callable[ [ AccountInfoWithRecipeIdAndUserId, - Optional[AccountLinkingUser], + Optional[User], Optional[SessionContainer], str, Dict[str, Any], diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index 7e39bfc40..5a5efe47f 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -16,7 +16,7 @@ from .types import ( AccountLinkingConfig, - AccountLinkingUser, + User, RecipeLevelUser, AccountInfoWithRecipeIdAndUserId, SessionContainer, @@ -30,9 +30,7 @@ from supertokens_python.supertokens import AppInfo -async def default_on_account_linked( - _: AccountLinkingUser, __: RecipeLevelUser, ___: Dict[str, Any] -): +async def default_on_account_linked(_: User, __: RecipeLevelUser, ___: Dict[str, Any]): pass @@ -41,7 +39,7 @@ async def default_on_account_linked( async def default_should_do_automatic_account_linking( _: AccountInfoWithRecipeIdAndUserId, - ___: Optional[AccountLinkingUser], + ___: Optional[User], ____: Optional[SessionContainer], _____: str, ______: Dict[str, Any], @@ -56,13 +54,13 @@ def recipe_init_defined_should_do_automatic_account_linking() -> bool: def validate_and_normalise_user_input( _: AppInfo, on_account_linked: Optional[ - Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] ] = None, should_do_automatic_account_linking: Optional[ Callable[ [ AccountInfoWithRecipeIdAndUserId, - Optional[AccountLinkingUser], + Optional[User], Optional[SessionContainer], str, Dict[str, Any], diff --git a/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py b/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py index 35717cdca..3fe8a55af 100644 --- a/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py +++ b/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py @@ -8,11 +8,11 @@ SignUpOkResult, ) from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe -from supertokens_python.types import APIResponse, AccountLinkingUser, RecipeUserId +from supertokens_python.types import APIResponse, User, RecipeUserId class CreateEmailPasswordUserOkResponse(APIResponse): - def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): + def __init__(self, user: User, recipe_user_id: RecipeUserId): self.status = "OK" self.user = user self.recipe_user_id = recipe_user_id diff --git a/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py b/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py index 51bed90b7..00e470194 100644 --- a/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py +++ b/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py @@ -24,7 +24,7 @@ ) from supertokens_python.recipe.passwordless.asyncio import signinup from supertokens_python.recipe.passwordless.recipe import PasswordlessRecipe -from supertokens_python.types import APIResponse, AccountLinkingUser, RecipeUserId +from supertokens_python.types import APIResponse, User, RecipeUserId from phonenumbers import parse as parse_phone_number, format_number, PhoneNumberFormat @@ -32,7 +32,7 @@ class CreatePasswordlessUserOkResponse(APIResponse): def __init__( self, created_new_recipe_user: bool, - user: AccountLinkingUser, + user: User, recipe_user_id: RecipeUserId, ): self.status: Literal["OK"] = "OK" diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index b66d3d6a5..ba7eae3bf 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -22,7 +22,7 @@ from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.passwordless import PasswordlessRecipe from supertokens_python.recipe.thirdparty import ThirdPartyRecipe -from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.types import User, RecipeUserId from supertokens_python.utils import log_debug_message, normalise_email from ...normalised_url_path import NormalisedURLPath @@ -48,13 +48,13 @@ class UserWithMetadata: - user: AccountLinkingUser + user: User first_name: Optional[str] = None last_name: Optional[str] = None def from_user( self, - user: AccountLinkingUser, + user: User, first_name: Optional[str] = None, last_name: Optional[str] = None, ): @@ -183,9 +183,7 @@ def get_api_if_matched(path: NormalisedURLPath, method: str) -> Optional[str]: class GetUserForRecipeIdHelperResult: - def __init__( - self, user: Optional[AccountLinkingUser] = None, recipe: Optional[str] = None - ): + def __init__(self, user: Optional[User] = None, recipe: Optional[str] = None): self.user = user self.recipe = recipe diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 42f71fba5..2cdbda29a 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -31,14 +31,14 @@ from supertokens_python.recipe.session import SessionContainer from .types import FormField - from ...types import AccountLinkingUser + from ...types import User from .utils import EmailPasswordConfig class SignUpOkResult: status: str = "OK" - def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): + def __init__(self, user: User, recipe_user_id: RecipeUserId): self.user = user self.recipe_user_id = recipe_user_id @@ -51,7 +51,7 @@ def to_json(self) -> Dict[str, Any]: class SignInOkResult: - def __init__(self, user: AccountLinkingUser, recipe_user_id: RecipeUserId): + def __init__(self, user: User, recipe_user_id: RecipeUserId): self.user = user self.recipe_user_id = recipe_user_id @@ -243,7 +243,7 @@ def to_json(self) -> Dict[str, Any]: class PasswordResetPostOkResult(APIResponse): status: str = "OK" - def __init__(self, email: str, user: AccountLinkingUser): + def __init__(self, email: str, user: User): self.email = email self.user = user @@ -254,7 +254,7 @@ def to_json(self) -> Dict[str, Any]: class SignInPostOkResult(APIResponse): status: str = "OK" - def __init__(self, user: AccountLinkingUser, session: SessionContainer): + def __init__(self, user: User, session: SessionContainer): self.user = user self.session = session @@ -278,7 +278,7 @@ def to_json(self) -> Dict[str, Any]: class SignUpPostOkResult(APIResponse): status: str = "OK" - def __init__(self, user: AccountLinkingUser, session: SessionContainer): + def __init__(self, user: User, session: SessionContainer): self.user = user self.session = session diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index f6a622f76..b54627641 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -36,7 +36,7 @@ from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe_module import APIHandled, RecipeModule -from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.types import User, RecipeUserId from .api.implementation import APIImplementation from .exceptions import FieldError, SuperTokensEmailPasswordError @@ -136,7 +136,7 @@ async def f1(_: TenantConfig): ) async def get_factors_setup_for_user( - user: AccountLinkingUser, _: Dict[str, Any] + user: User, _: Dict[str, Any] ) -> List[str]: for login_method in user.login_methods: # We don't check for tenant_id here because if we find the user @@ -154,7 +154,7 @@ async def get_factors_setup_for_user( ) async def get_emails_for_factor( - user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + user: User, session_recipe_user_id: RecipeUserId ) -> Union[ GetEmailsForFactorOkResult, GetEmailsForFactorUnknownSessionRecipeUserIdResult, diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index ea132d6aa..7feb13720 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -42,7 +42,7 @@ LinkingToSessionUserFailedError, link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, ) -from ...types import AccountLinkingUser +from ...types import User if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -114,7 +114,7 @@ async def create_new_recipe_user( ) if response["status"] == "OK": return SignUpOkResult( - user=AccountLinkingUser.from_json(response["user"]), + user=User.from_json(response["user"]), recipe_user_id=RecipeUserId(response["recipeUserId"]), ) return EmailAlreadyExistsError() @@ -191,7 +191,7 @@ async def verify_credentials( if response["status"] == "OK": return SignInOkResult( - user=AccountLinkingUser.from_json(response["user"]), + user=User.from_json(response["user"]), recipe_user_id=RecipeUserId(response["recipeUserId"]), ) diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 77620b78c..1c6e7c1a7 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -68,7 +68,7 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.supertokens import AppInfo from supertokens_python.types import RecipeUserId - from ...types import AccountLinkingUser, MaybeAwaitable + from ...types import User, MaybeAwaitable from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier @@ -270,7 +270,7 @@ def reset(): async def get_email_for_recipe_user_id( self, - user: Optional[AccountLinkingUser], + user: Optional[User], recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError]: diff --git a/supertokens_python/recipe/emailverification/recipe_implementation.py b/supertokens_python/recipe/emailverification/recipe_implementation.py index 097ecbe54..fce767d10 100644 --- a/supertokens_python/recipe/emailverification/recipe_implementation.py +++ b/supertokens_python/recipe/emailverification/recipe_implementation.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: from supertokens_python.querier import Querier - from supertokens_python.types import RecipeUserId, AccountLinkingUser + from supertokens_python.types import RecipeUserId, User class RecipeImplementation(RecipeInterface): @@ -42,7 +42,7 @@ def __init__( self, querier: Querier, get_email_for_recipe_user_id: Callable[ - [Optional[AccountLinkingUser], RecipeUserId, Dict[str, Any]], + [Optional[User], RecipeUserId, Dict[str, Any]], Awaitable[ Union[ GetEmailForUserIdOkResult, diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index ea8ec20b4..9e5632a2b 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -18,7 +18,7 @@ from typing import Dict, Any, Union, List, Callable, Awaitable from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe -from supertokens_python.types import AccountLinkingUser +from supertokens_python.types import User from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union @@ -51,7 +51,7 @@ async def get_mfa_requirements_for_auth( tenant_id: str, access_token_payload: Dict[str, Any], completed_factors: Dict[str, int], - user: Callable[[], Awaitable[AccountLinkingUser]], + user: Callable[[], Awaitable[User]], factors_set_up_for_user: Callable[[], Awaitable[List[str]]], required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], @@ -70,7 +70,7 @@ async def mark_factor_as_complete_in_session( @abstractmethod async def get_factors_setup_for_user( - self, user: AccountLinkingUser, user_context: Dict[str, Any] + self, user: User, user_context: Dict[str, Any] ) -> List[str]: pass diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index 30bdcaeb1..f45db7316 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -35,7 +35,7 @@ from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo -from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.types import User, RecipeUserId from .types import ( OverrideConfig, GetFactorsSetupForUserFromOtherRecipesFunc, @@ -231,7 +231,7 @@ def add_func_to_get_emails_for_factor_from_other_recipes( self.get_emails_for_factor_from_other_recipes_funcs.append(func) async def get_emails_for_factors( - self, user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + self, user: User, session_recipe_user_id: RecipeUserId ) -> Union[ GetEmailsForFactorOkResult, GetEmailsForFactorUnknownSessionRecipeUserIdResult, @@ -255,7 +255,7 @@ def add_func_to_get_phone_numbers_for_factors_from_other_recipes( self.get_phone_numbers_for_factor_from_other_recipes_funcs.append(func) async def get_phone_numbers_for_factors( - self, user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + self, user: User, session_recipe_user_id: RecipeUserId ) -> Union[ GetPhoneNumbersForFactorsOkResult, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py index 89797a17e..b0909fc48 100644 --- a/supertokens_python/recipe/multifactorauth/recipe_implementation.py +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -37,7 +37,7 @@ from .interfaces import RecipeInterface from .recipe import MultiFactorAuthRecipe -from supertokens_python.types import AccountLinkingUser +from supertokens_python.types import User from .utils import update_and_get_mfa_related_info_in_session @@ -121,7 +121,7 @@ def __init__( self.recipe_instance = recipe_instance async def get_factors_setup_for_user( - self, user: AccountLinkingUser, user_context: Dict[str, Any] + self, user: User, user_context: Dict[str, Any] ) -> List[str]: factor_ids: List[str] = [] for ( @@ -138,7 +138,7 @@ async def get_mfa_requirements_for_auth( tenant_id: str, access_token_payload: Dict[str, Any], completed_factors: Dict[str, int], - user: Callable[[], Awaitable[AccountLinkingUser]], + user: Callable[[], Awaitable[User]], factors_set_up_for_user: Callable[[], Awaitable[List[str]]], required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 0a82e61c2..d070d0518 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -17,7 +17,7 @@ from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from .interfaces import RecipeInterface, APIInterface from typing_extensions import Literal -from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.types import User, RecipeUserId class MFARequirementList(List[Union[Dict[str, List[str]], str]]): @@ -92,7 +92,7 @@ def __init__( class GetFactorsSetupForUserFromOtherRecipesFunc: def __init__( self, - func: Callable[[AccountLinkingUser, Dict[str, Any]], Awaitable[List[str]]], + func: Callable[[User, Dict[str, Any]], Awaitable[List[str]]], ): self.func = func @@ -120,7 +120,7 @@ class GetEmailsForFactorFromOtherRecipesFunc: def __init__( self, func: Callable[ - [AccountLinkingUser, RecipeUserId], + [User, RecipeUserId], Awaitable[ Union[ GetEmailsForFactorOkResult, @@ -147,7 +147,7 @@ class GetPhoneNumbersForFactorsFromOtherRecipesFunc: def __init__( self, func: Callable[ - [AccountLinkingUser, RecipeUserId], + [User, RecipeUserId], Awaitable[ Union[ GetPhoneNumbersForFactorsOkResult, diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 9b0ed20c4..16a4e8d29 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -64,7 +64,7 @@ from supertokens_python.recipe.session.exceptions import UnauthorisedError from supertokens_python.types import ( AccountInfo, - AccountLinkingUser, + User, GeneralErrorResponse, LoginMethod, RecipeUserId, @@ -74,12 +74,10 @@ class PasswordlessUserResult: - user: AccountLinkingUser + user: User login_method: Union[LoginMethod, None] - def __init__( - self, user: AccountLinkingUser, login_method: Union[LoginMethod, None] - ): + def __init__(self, user: User, login_method: Union[LoginMethod, None]): self.user = user self.login_method = login_method @@ -691,7 +689,7 @@ async def check_credentials(_: str): reason = reason_dict[response.reason] return SignInUpPostNotAllowedResponse(reason=reason) - authenticating_user_input: AccountLinkingUser + authenticating_user_input: User if response.user: authenticating_user_input = response.user elif authenticating_user: diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 5247eef4c..a1635a5b3 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -24,7 +24,7 @@ from supertokens_python.recipe.session import SessionContainer from supertokens_python.types import ( APIResponse, - AccountLinkingUser, + User, GeneralErrorResponse, RecipeUserId, ) @@ -116,7 +116,7 @@ class ConsumeCodeOkResult: def __init__( self, created_new_recipe_user: bool, - user: AccountLinkingUser, + user: User, recipe_user_id: RecipeUserId, consumed_device: ConsumedDevice, ): @@ -400,7 +400,7 @@ class ConsumeCodePostOkResult(APIResponse): def __init__( self, created_new_recipe_user: bool, - user: AccountLinkingUser, + user: User, session: SessionContainer, ): self.created_new_recipe_user = created_new_recipe_user diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index 1d59d3483..9e96b0e5a 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -42,7 +42,7 @@ from typing_extensions import Literal from supertokens_python.recipe.session import SessionContainer -from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.types import User, RecipeUserId from .api import ( consume_code, @@ -171,11 +171,9 @@ async def f1(_: TenantConfig): ) async def get_factors_setup_for_user( - user: AccountLinkingUser, _: Dict[str, Any] + user: User, _: Dict[str, Any] ) -> List[str]: - def is_factor_setup_for_user( - user: AccountLinkingUser, factor_id: str - ) -> bool: + def is_factor_setup_for_user(user: User, factor_id: str) -> bool: for login_method in user.login_methods: if login_method.recipe_id != "passwordless": continue @@ -210,7 +208,7 @@ def is_factor_setup_for_user( ) async def get_emails_for_factor( - user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + user: User, session_recipe_user_id: RecipeUserId ) -> Union[ GetEmailsForFactorOkResult, GetEmailsForFactorUnknownSessionRecipeUserIdResult, @@ -315,7 +313,7 @@ async def get_emails_for_factor( ) async def get_phone_numbers_for_factors( - user: AccountLinkingUser, session_recipe_user_id: RecipeUserId + user: User, session_recipe_user_id: RecipeUserId ) -> Union[ GetPhoneNumbersForFactorsOkResult, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 6c65a4435..46191d5c7 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -48,7 +48,7 @@ UpdateUserUnknownUserIdError, ) from supertokens_python.recipe.session import SessionContainer -from supertokens_python.types import AccountLinkingUser, RecipeUserId +from supertokens_python.types import User, RecipeUserId from supertokens_python.utils import log_debug_message from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe @@ -111,7 +111,7 @@ async def consume_code( recipe_user_id = RecipeUserId(response["recipeUserId"]) - updated_user = AccountLinkingUser.from_json(response["user"]) + updated_user = User.from_json(response["user"]) link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( tenant_id=tenant_id, diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index 1d106a79b..32edc6828 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -18,7 +18,7 @@ from supertokens_python.auth_utils import LinkingToSessionUserFailedError -from ...types import APIResponse, AccountLinkingUser, GeneralErrorResponse, RecipeUserId +from ...types import APIResponse, User, GeneralErrorResponse, RecipeUserId from .provider import Provider, ProviderInput, RedirectUriInfo if TYPE_CHECKING: @@ -33,7 +33,7 @@ class SignInUpOkResult: def __init__( self, - user: AccountLinkingUser, + user: User, recipe_user_id: RecipeUserId, created_new_recipe_user: bool, oauth_tokens: Dict[str, Any], @@ -49,7 +49,7 @@ def __init__( class ManuallyCreateOrUpdateUserOkResult: def __init__( self, - user: AccountLinkingUser, + user: User, recipe_user_id: RecipeUserId, created_new_recipe_user: bool, ): @@ -152,7 +152,7 @@ class SignInUpPostOkResult(APIResponse): def __init__( self, - user: AccountLinkingUser, + user: User, created_new_recipe_user: bool, session: SessionContainer, oauth_tokens: Dict[str, Any], diff --git a/supertokens_python/recipe/thirdparty/recipe_implementation.py b/supertokens_python/recipe/thirdparty/recipe_implementation.py index de2db33e0..becfc077c 100644 --- a/supertokens_python/recipe/thirdparty/recipe_implementation.py +++ b/supertokens_python/recipe/thirdparty/recipe_implementation.py @@ -30,7 +30,7 @@ find_and_create_provider_instance, merge_providers_from_core_and_static, ) -from supertokens_python.types import AccountInfo, AccountLinkingUser, RecipeUserId +from supertokens_python.types import AccountInfo, User, RecipeUserId if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -152,7 +152,7 @@ async def manually_create_or_update_user( # status is OK - user = AccountLinkingUser.from_json( + user = User.from_json( response["user"], ) recipe_user_id = RecipeUserId(response["recipeUserId"]) diff --git a/supertokens_python/recipe/thirdparty/types.py b/supertokens_python/recipe/thirdparty/types.py index 8b2be0d66..4a2dc7b9a 100644 --- a/supertokens_python/recipe/thirdparty/types.py +++ b/supertokens_python/recipe/thirdparty/types.py @@ -18,7 +18,7 @@ from supertokens_python.framework.request import BaseRequest if TYPE_CHECKING: - from supertokens_python.types import AccountLinkingUser + from supertokens_python.types import User class ThirdPartyInfo: @@ -79,7 +79,7 @@ def __init__( class SignInUpResponse: - def __init__(self, user: AccountLinkingUser, is_new_user: bool): + def __init__(self, user: User, is_new_user: bool): self.user = user self.is_new_user = is_new_user diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py index d9438346c..16be078a0 100644 --- a/supertokens_python/recipe/totp/recipe.py +++ b/supertokens_python/recipe/totp/recipe.py @@ -24,7 +24,7 @@ from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.querier import Querier -from supertokens_python.types import AccountLinkingUser +from supertokens_python.types import User from .recipe_implementation import RecipeImplementation from .api.implementation import APIImplementation @@ -92,9 +92,7 @@ def callback(): async def f1(_: TenantConfig): return ["totp"] - async def f2( - user: AccountLinkingUser, user_context: Dict[str, Any] - ) -> List[str]: + async def f2(user: User, user_context: Dict[str, Any]) -> List[str]: device_res = await TOTPRecipe.get_instance_or_throw().recipe_implementation.list_devices( user_id=user.id, user_context=user_context ) diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index e7f49520b..e7b50292f 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -25,7 +25,7 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) -from supertokens_python.types import AccountInfo, AccountLinkingUser +from supertokens_python.types import AccountInfo, User def get_users_oldest_first( @@ -96,7 +96,7 @@ def delete_user( def get_user( user_id: str, user_context: Optional[Dict[str, Any]] = None -) -> Optional[AccountLinkingUser]: +) -> Optional[User]: from supertokens_python.asyncio import get_user as async_get_user return sync(async_get_user(user_id, user_context)) @@ -167,7 +167,7 @@ def list_users_by_account_info( account_info: AccountInfo, do_union_of_account_info: bool = False, user_context: Optional[Dict[str, Any]] = None, -) -> List[AccountLinkingUser]: +) -> List[User]: from supertokens_python.asyncio import ( list_users_by_account_info as async_list_users_by_account_info, ) diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 14ef4e4a4..eba2d19a3 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -131,7 +131,7 @@ def from_json(json: Dict[str, Any]) -> "LoginMethod": ) -class AccountLinkingUser: +class User: def __init__( self, user_id: str, @@ -165,8 +165,8 @@ def to_json(self) -> Dict[str, Any]: } @staticmethod - def from_json(json: Dict[str, Any]) -> "AccountLinkingUser": - return AccountLinkingUser( + def from_json(json: Dict[str, Any]) -> "User": + return User( user_id=json["id"], is_primary_user=json["isPrimaryUser"], tenant_ids=json["tenantIds"], @@ -178,45 +178,6 @@ def from_json(json: Dict[str, Any]) -> "AccountLinkingUser": ) -class User: - def __init__( - self, - recipe_id: str, - user_id: str, - time_joined: int, - email: Union[str, None], - phone_number: Union[str, None], - third_party_info: Union[ThirdPartyInfo, None], - tenant_ids: List[str], - ): - self.recipe_id = recipe_id - self.user_id = user_id - self.email = email - self.time_joined = time_joined - self.third_party_info = third_party_info - self.phone_number = phone_number - self.tenant_ids = tenant_ids - - def to_json(self) -> Dict[str, Any]: - res: Dict[str, Any] = { - "recipeId": self.recipe_id, - "user": { - "id": self.user_id, - "timeJoined": self.time_joined, - "tenantIds": self.tenant_ids, - }, - } - - if self.email is not None: - res["user"]["email"] = self.email - if self.phone_number is not None: - res["user"]["phoneNumber"] = self.phone_number - if self.third_party_info is not None: - res["user"]["thirdParty"] = self.third_party_info.__dict__ - - return res - - class APIResponse(ABC): @abstractmethod def to_json(self) -> Dict[str, Any]: diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index 26fc09be6..2ea628a68 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -50,7 +50,7 @@ from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER, FDI_KEY_HEADER from .exceptions import raise_general_exception from .types import MaybeAwaitable -from supertokens_python.types import AccountLinkingUser +from supertokens_python.types import User _T = TypeVar("_T") @@ -302,7 +302,7 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str: def get_backwards_compatible_user_info( req: BaseRequest, - user_info: AccountLinkingUser, + user_info: User, session_container: SessionContainer, created_new_recipe_user: Union[bool, None], user_context: Dict[str, Any], diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index fe711cf53..dd8da11ac 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -104,7 +104,7 @@ ) from supertokens_python.types import ( AccountInfo, - AccountLinkingUser, + User, GeneralErrorResponse, ) from supertokens_python.syncio import delete_user, list_users_by_account_info @@ -200,7 +200,7 @@ async def send_email( async def create_and_send_custom_email( - _: AccountLinkingUser, url_with_token: str, __: Dict[str, Any] + _: User, url_with_token: str, __: Dict[str, Any] ) -> None: global latest_url_with_token latest_url_with_token = url_with_token From aa7cd079122e35a7b612c698d27cfbe9d4358f33 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 18 Sep 2024 15:14:11 +0530 Subject: [PATCH 044/126] more changes --- CHANGELOG.md | 3 + coreDriverInterfaceSupported.json | 4 +- frontendDriverInterfaceSupported.json | 6 +- setup.py | 2 +- supertokens_python/auth_utils.py | 61 ++++++++++++++++-- supertokens_python/constants.py | 4 +- .../emailpassword/api/implementation.py | 6 ++ .../recipe/emailpassword/api/signin.py | 23 ++++--- .../recipe/emailpassword/api/signup.py | 22 +++++-- .../recipe/emailpassword/asyncio/__init__.py | 14 ++++- .../recipe/emailpassword/interfaces.py | 4 ++ .../emailpassword/recipe_implementation.py | 4 ++ .../recipe/passwordless/api/consume_code.py | 16 +++-- .../recipe/passwordless/api/create_code.py | 22 ++++--- .../recipe/passwordless/api/implementation.py | 8 +++ .../recipe/passwordless/api/resend_code.py | 27 +++++--- .../recipe/passwordless/asyncio/__init__.py | 2 + .../recipe/passwordless/interfaces.py | 5 ++ .../recipe/passwordless/recipe.py | 3 + .../passwordless/recipe_implementation.py | 3 + .../recipe/thirdparty/api/implementation.py | 3 + .../recipe/thirdparty/api/signinup.py | 16 +++-- .../recipe/thirdparty/asyncio/__init__.py | 1 + .../recipe/thirdparty/interfaces.py | 3 + .../thirdparty/recipe_implementation.py | 4 ++ supertokens_python/utils.py | 8 +++ tests/auth-react/django3x/mysite/utils.py | 31 +++++++++- tests/auth-react/fastapi-server/app.py | 31 +++++++++- tests/auth-react/flask-server/app.py | 31 +++++++++- tests/test_user_context.py | 62 ++++++++++++++++--- 30 files changed, 363 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a21ff06c6..4fb5160fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.25.0] - 2024-09-18 + ### Breaking changes - `supertokens_python.recipe.emailverification.types.User` has been renamed to `supertokens_python.recipe.emailverification.types.EmailVerificationUser` - The user object has been changed to be a global one, containing information about all emails, phone numbers, third party info and login methods associated with that user. - Type of `get_email_for_user_id` in `emailverification.init` has changed - Session recipe's error handlers take an extra param of recipe_user_id as well - Session recipe, removes `validate_claims_in_jwt_payload` that is exposed to the user. +- TODO.. ## [0.24.2] - 2024-09-03 - Makes optional input form fields truly optional instead of just being able to accept `""`. diff --git a/coreDriverInterfaceSupported.json b/coreDriverInterfaceSupported.json index 7ebde0bbf..11cc869d7 100644 --- a/coreDriverInterfaceSupported.json +++ b/coreDriverInterfaceSupported.json @@ -1,6 +1,6 @@ { "_comment": "contains a list of core-driver interfaces branch names that this core supports", "versions": [ - "3.1" + "5.1" ] -} +} \ No newline at end of file diff --git a/frontendDriverInterfaceSupported.json b/frontendDriverInterfaceSupported.json index 0d7a266f3..71d6714d4 100644 --- a/frontendDriverInterfaceSupported.json +++ b/frontendDriverInterfaceSupported.json @@ -2,6 +2,10 @@ "_comment": "contains a list of frontend-driver interfaces branch names that this core supports", "versions": [ "1.17", - "2.0" + "1.18", + "1.19", + "2.0", + "3.0", + "3.1" ] } \ No newline at end of file diff --git a/setup.py b/setup.py index b779e2733..41143697e 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ setup( name="supertokens_python", - version="0.24.2", + version="0.25.0", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 9b22c6342..b0df1a895 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -22,7 +22,7 @@ from supertokens_python.recipe.multitenancy.asyncio import associate_user_to_tenant from supertokens_python.recipe.session.interfaces import SessionContainer from supertokens_python.recipe.session.recipe import SessionRecipe -from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.recipe.session.asyncio import create_new_session, get_session from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.types import ( AccountInfo, @@ -36,7 +36,7 @@ from supertokens_python.recipe.emailverification import ( EmailVerificationClaim, ) -from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.exceptions import BadInputError, raise_bad_input_exception from supertokens_python.utils import log_debug_message from .asyncio import get_user @@ -93,6 +93,7 @@ async def pre_auth_checks( sign_in_verifies_login_method: bool, skip_session_user_update_in_core: bool, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ) -> Union[ OkResponse, @@ -110,6 +111,7 @@ async def pre_auth_checks( log_debug_message("preAuthChecks checking auth types") auth_type_info = await check_auth_type_and_linking_status( session, + should_try_linking_with_session_user, authenticating_account_info, authenticating_user, skip_session_user_update_in_core, @@ -479,6 +481,7 @@ def __init__( async def check_auth_type_and_linking_status( session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], account_info: AccountInfoWithRecipeId, input_user: Union[User, None], skip_session_user_update_in_core: bool, @@ -492,18 +495,34 @@ async def check_auth_type_and_linking_status( log_debug_message("check_auth_type_and_linking_status called") session_user: Union[User, None] = None if session is None: + if should_try_linking_with_session_user is True: + raise UnauthorisedError( + "Session not found but shouldTryLinkingWithSessionUser is true" + ) log_debug_message( "check_auth_type_and_linking_status returning first factor because there is no session" ) return OkFirstFactorResponse() else: + if should_try_linking_with_session_user is False: + # In our normal flows this should never happen - but some user overrides might do this. + # Anyway, since should_try_linking_with_session_user explicitly set to false, it's safe to consider this a first factor + log_debug_message( + "check_auth_type_and_linking_status returning first factor because should_try_linking_with_session_user is False" + ) + return OkFirstFactorResponse() if not recipe_init_defined_should_do_automatic_account_linking(): - if MultiFactorAuthRecipe.get_instance() is not None: + if should_try_linking_with_session_user is True: raise Exception( "Please initialise the account linking recipe and define should_do_automatic_account_linking to enable MFA" ) else: - return OkFirstFactorResponse() + if MultiFactorAuthRecipe.get_instance() is not None: + raise Exception( + "Please initialise the account linking recipe and define should_do_automatic_account_linking to enable MFA" + ) + else: + return OkFirstFactorResponse() if input_user is not None and input_user.id == session.get_user_id(): log_debug_message( @@ -520,6 +539,10 @@ async def check_auth_type_and_linking_status( session, skip_session_user_update_in_core, user_context ) if session_user_result.status == "SHOULD_AUTOMATICALLY_LINK_FALSE": + if should_try_linking_with_session_user is True: + raise BadInputError( + "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + ) return OkFirstFactorResponse() elif ( session_user_result.status @@ -545,6 +568,10 @@ async def check_auth_type_and_linking_status( ) if isinstance(should_link, ShouldNotAutomaticallyLink): + if should_try_linking_with_session_user is True: + raise BadInputError( + "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + ) return OkFirstFactorResponse() else: return OkSecondFactorNotLinkedResponse( @@ -567,6 +594,7 @@ async def link_to_session_if_provided_else_create_primary_user_id_or_link_by_acc input_user: User, recipe_user_id: RecipeUserId, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ) -> Union[OkResponse2, LinkingToSessionUserFailedError,]: log_debug_message( @@ -582,6 +610,7 @@ async def retry(): input_user=input_user, session=session, recipe_user_id=recipe_user_id, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, ) @@ -600,6 +629,7 @@ async def retry(): auth_type_res = await check_auth_type_and_linking_status( session, + should_try_linking_with_session_user, AccountInfoWithRecipeId( recipe_id=auth_login_method.recipe_id, email=auth_login_method.email, @@ -952,3 +982,26 @@ def is_fake_email(email: str) -> bool: return email.endswith("@stfakeemail.supertokens.com") or email.endswith( ".fakeemail.com" ) # .fakeemail.com for older users + + +async def load_session_in_auth_api_if_needed( + request: BaseRequest, + should_try_linking_with_session_user: Optional[bool], + user_context: Dict[str, Any], +) -> Optional[SessionContainer]: + + overwrite_session_during_sign_in_up = ( + SessionRecipe.get_instance().config.overwrite_session_during_sign_in_up + ) + + if ( + should_try_linking_with_session_user is not False + or not overwrite_session_during_sign_in_up + ): + return await get_session( + request, + session_required=should_try_linking_with_session_user is True, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + return None diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 5d19a0873..9225cf37a 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -14,8 +14,8 @@ from __future__ import annotations -SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.24.2" +SUPPORTED_CDI_VERSIONS = ["5.1"] +VERSION = "0.25.0" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" diff --git a/supertokens_python/recipe/emailpassword/api/implementation.py b/supertokens_python/recipe/emailpassword/api/implementation.py index c0568e0fb..8d1a483fa 100644 --- a/supertokens_python/recipe/emailpassword/api/implementation.py +++ b/supertokens_python/recipe/emailpassword/api/implementation.py @@ -458,6 +458,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ @@ -527,6 +528,7 @@ async def check_credentials_on_tenant(tenant_id: str) -> bool: sign_in_verifies_login_method=False, skip_session_user_update_in_core=False, tenant_id=tenant_id, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, session=session, ) @@ -553,6 +555,7 @@ async def check_credentials_on_tenant(tenant_id: str) -> bool: session=session, tenant_id=tenant_id, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if isinstance(sign_in_response, WrongCredentialsError): @@ -589,6 +592,7 @@ async def sign_up_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ @@ -624,6 +628,7 @@ async def sign_up_post( tenant_id=tenant_id, user_context=user_context, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if pre_auth_check_res.status == "SIGN_UP_NOT_ALLOWED": @@ -667,6 +672,7 @@ async def sign_up_post( password=password, session=session, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if isinstance(sign_up_response, EmailAlreadyExistsError): diff --git a/supertokens_python/recipe/emailpassword/api/signin.py b/supertokens_python/recipe/emailpassword/api/signin.py index c41e02a15..af1e893ce 100644 --- a/supertokens_python/recipe/emailpassword/api/signin.py +++ b/supertokens_python/recipe/emailpassword/api/signin.py @@ -14,10 +14,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.recipe.emailpassword.interfaces import SignInPostOkResult -from supertokens_python.recipe.session.asyncio import get_session - if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions, @@ -27,6 +26,7 @@ from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.utils import ( get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, send_200_response, ) @@ -49,16 +49,25 @@ async def handle_sign_in_api( api_options.config.sign_in_feature.form_fields, form_fields_raw, tenant_id ) - session = await get_session( - api_options.request, - override_global_claim_validators=lambda _, __, ___: [], - user_context=user_context, + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context ) if session is not None: tenant_id = session.get_tenant_id() response = await api_implementation.sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) if isinstance(response, SignInPostOkResult): diff --git a/supertokens_python/recipe/emailpassword/api/signup.py b/supertokens_python/recipe/emailpassword/api/signup.py index 7625dc888..529cbb866 100644 --- a/supertokens_python/recipe/emailpassword/api/signup.py +++ b/supertokens_python/recipe/emailpassword/api/signup.py @@ -14,9 +14,9 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.recipe.emailpassword.interfaces import SignUpPostOkResult -from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.types import GeneralErrorResponse from ..exceptions import raise_form_field_exception @@ -31,6 +31,7 @@ from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.utils import ( get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, send_200_response, ) @@ -53,17 +54,26 @@ async def handle_sign_up_api( api_options.config.sign_up_feature.form_fields, form_fields_raw, tenant_id ) - session = await get_session( - api_options.request, - override_global_claim_validators=lambda _, __, ___: [], - user_context=user_context, + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context ) if session is not None: tenant_id = session.get_tenant_id() response = await api_implementation.sign_up_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) if isinstance(response, SignUpPostOkResult): diff --git a/supertokens_python/recipe/emailpassword/asyncio/__init__.py b/supertokens_python/recipe/emailpassword/asyncio/__init__.py index decb7efa1..5f4f62795 100644 --- a/supertokens_python/recipe/emailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/asyncio/__init__.py @@ -54,7 +54,12 @@ async def sign_up( if user_context is None: user_context = {} return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_up( - email, password, tenant_id or DEFAULT_TENANT_ID, session, user_context + email=email, + password=password, + tenant_id=tenant_id or DEFAULT_TENANT_ID, + session=session, + user_context=user_context, + should_try_linking_with_session_user=session is not None, ) @@ -68,7 +73,12 @@ async def sign_in( if user_context is None: user_context = {} return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_in( - email, password, tenant_id or DEFAULT_TENANT_ID, session, user_context + email=email, + password=password, + tenant_id=tenant_id or DEFAULT_TENANT_ID, + session=session, + user_context=user_context, + should_try_linking_with_session_user=session is not None, ) diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 2cdbda29a..49e5441ef 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -116,6 +116,7 @@ async def sign_up( password: str, tenant_id: str, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ) -> Union[ SignUpOkResult, @@ -141,6 +142,7 @@ async def sign_in( password: str, tenant_id: str, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError,]: pass @@ -350,6 +352,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ @@ -366,6 +369,7 @@ async def sign_up_post( form_fields: List[FormField], tenant_id: str, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 7feb13720..f056b1e3a 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -64,6 +64,7 @@ async def sign_up( password: str, tenant_id: str, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ) -> Union[ SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError @@ -84,6 +85,7 @@ async def sign_up( input_user=response.user, recipe_user_id=response.recipe_user_id, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, ) @@ -125,6 +127,7 @@ async def sign_in( password: str, tenant_id: str, session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: response = await self.verify_credentials( @@ -163,6 +166,7 @@ async def sign_in( input_user=response.user, recipe_user_id=response.recipe_user_id, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, ) diff --git a/supertokens_python/recipe/passwordless/api/consume_code.py b/supertokens_python/recipe/passwordless/api/consume_code.py index bf2c24681..e385e4ea1 100644 --- a/supertokens_python/recipe/passwordless/api/consume_code.py +++ b/supertokens_python/recipe/passwordless/api/consume_code.py @@ -12,15 +12,16 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import ( APIInterface, APIOptions, ConsumeCodePostOkResult, ) -from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.utils import ( get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, send_200_response, ) @@ -64,10 +65,14 @@ async def consume_code( pre_auth_session_id = body["preAuthSessionId"] - session = await get_session( - api_options.request, - override_global_claim_validators=lambda _, __, ___: [], - user_context=user_context, + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context ) if session is not None: @@ -79,6 +84,7 @@ async def consume_code( device_id, link_code, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, diff --git a/supertokens_python/recipe/passwordless/api/create_code.py b/supertokens_python/recipe/passwordless/api/create_code.py index d69d69749..2c57519df 100644 --- a/supertokens_python/recipe/passwordless/api/create_code.py +++ b/supertokens_python/recipe/passwordless/api/create_code.py @@ -14,7 +14,8 @@ from typing import Union, Any, Dict import phonenumbers # type: ignore -from phonenumbers import format_number, parse # type: ignore +from phonenumbers import format_number, parse +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed # type: ignore from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions from supertokens_python.recipe.passwordless.utils import ( @@ -22,9 +23,11 @@ ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, ) -from supertokens_python.recipe.session.asyncio import get_session from supertokens_python.types import GeneralErrorResponse -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) async def create_code( @@ -110,10 +113,14 @@ async def create_code( except Exception: phone_number = phone_number.strip() - session = await get_session( - api_options.request, - override_global_claim_validators=lambda _, __, ___: [], - user_context=user_context, + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context ) if session is not None: @@ -126,5 +133,6 @@ async def create_code( tenant_id=tenant_id, api_options=api_options, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 16a4e8d29..c1eb9fd1e 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -139,6 +139,7 @@ async def create_code_post( email: Union[str, None], phone_number: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -205,6 +206,7 @@ async def create_code_post( factor_ids=factor_ids, user_context=user_context, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if not isinstance(pre_auth_checks_result, OkResponse): @@ -240,6 +242,7 @@ async def create_code_post( tenant_id=tenant_id, user_context=user_context, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) magic_link = None @@ -323,6 +326,7 @@ async def resend_code_post( device_id: str, pre_auth_session_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -368,6 +372,7 @@ async def resend_code_post( ), skip_session_user_update_in_core=True, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if auth_type_info.status == "LINKING_TO_SESSION_USER_FAILED": @@ -501,6 +506,7 @@ async def consume_code_post( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -640,6 +646,7 @@ async def check_credentials(_: str): tenant_id=tenant_id, user_context=user_context, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if not isinstance(pre_auth_checks_result, OkResponse): @@ -669,6 +676,7 @@ async def check_credentials(_: str): session=session, tenant_id=tenant_id, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if isinstance(response, ConsumeCodeRestartFlowError): diff --git a/supertokens_python/recipe/passwordless/api/resend_code.py b/supertokens_python/recipe/passwordless/api/resend_code.py index 619d8e687..8ee4cb762 100644 --- a/supertokens_python/recipe/passwordless/api/resend_code.py +++ b/supertokens_python/recipe/passwordless/api/resend_code.py @@ -12,10 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions -from supertokens_python.recipe.session.asyncio import get_session -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) async def resend_code( @@ -40,16 +43,26 @@ async def resend_code( pre_auth_session_id = body["preAuthSessionId"] device_id = body["deviceId"] - session = await get_session( - api_options.request, - override_global_claim_validators=lambda _, __, ___: [], - user_context=user_context, + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context ) if session is not None: tenant_id = session.get_tenant_id() result = await api_implementation.resend_code_post( - device_id, pre_auth_session_id, session, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/passwordless/asyncio/__init__.py b/supertokens_python/recipe/passwordless/asyncio/__init__.py index 0c977872b..8f29be622 100644 --- a/supertokens_python/recipe/passwordless/asyncio/__init__.py +++ b/supertokens_python/recipe/passwordless/asyncio/__init__.py @@ -64,6 +64,7 @@ async def create_code( tenant_id=tenant_id, session=session, user_context=user_context, + should_try_linking_with_session_user=session is not None, ) @@ -111,6 +112,7 @@ async def consume_code( link_code=link_code, tenant_id=tenant_id, session=session, + should_try_linking_with_session_user=session is not None, user_context=user_context, ) diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index a1635a5b3..162dd0b3e 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -209,6 +209,7 @@ async def create_code( phone_number: Union[None, str], user_input_code: Union[None, str], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> CreateCodeOkResult: @@ -236,6 +237,7 @@ async def consume_code( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[ @@ -500,6 +502,7 @@ async def create_code_post( email: Union[str, None], phone_number: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -514,6 +517,7 @@ async def resend_code_post( device_id: str, pre_auth_session_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -530,6 +534,7 @@ async def consume_code_post( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index 9e96b0e5a..af01c88a0 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -583,6 +583,7 @@ async def create_magic_link( user_input_code=user_input_code, user_context=user_context, session=None, + should_try_linking_with_session_user=False, ) app_info = self.get_app_info() @@ -615,6 +616,7 @@ async def signinup( tenant_id=tenant_id, user_context=user_context, session=session, + should_try_linking_with_session_user=False, ) consume_code_result = await self.recipe_implementation.consume_code( link_code=code_info.link_code, @@ -624,6 +626,7 @@ async def signinup( tenant_id=tenant_id, user_context=user_context, session=session, + should_try_linking_with_session_user=False, ) if isinstance(consume_code_result, ConsumeCodeOkResult): return consume_code_result diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 46191d5c7..17b1bc42e 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -66,6 +66,7 @@ async def consume_code( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[ @@ -119,6 +120,7 @@ async def consume_code( recipe_user_id=recipe_user_id, session=session, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if isinstance(link_result, LinkingToSessionUserFailedError): @@ -192,6 +194,7 @@ async def create_code( phone_number: Union[None, str], user_input_code: Union[None, str], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> CreateCodeOkResult: diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index a1dfd9bba..72809872d 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -73,6 +73,7 @@ async def sign_in_up_post( redirect_uri_info: Optional[RedirectUriInfo], oauth_tokens: Optional[Dict[str, Any]], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -194,6 +195,7 @@ async def check_credentials_on_tenant(_: str): tenant_id=tenant_id, user_context=user_context, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if not isinstance(pre_auth_checks_result, OkResponse): @@ -221,6 +223,7 @@ async def check_credentials_on_tenant(_: str): session=session, tenant_id=tenant_id, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if isinstance(signinup_response, SignInUpNotAllowed): diff --git a/supertokens_python/recipe/thirdparty/api/signinup.py b/supertokens_python/recipe/thirdparty/api/signinup.py index af94ccdc6..7ae2a2d2f 100644 --- a/supertokens_python/recipe/thirdparty/api/signinup.py +++ b/supertokens_python/recipe/thirdparty/api/signinup.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict -from supertokens_python.recipe.session.asyncio import get_session +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.recipe.thirdparty.interfaces import SignInUpPostOkResult from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo @@ -24,6 +24,7 @@ from supertokens_python.exceptions import raise_bad_input_exception, BadInputError from supertokens_python.utils import ( get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, send_200_response, ) @@ -85,10 +86,14 @@ async def handle_sign_in_up_api( pkce_code_verifier=redirect_uri_info.get("pkceCodeVerifier"), ) - session = await get_session( - api_options.request, - override_global_claim_validators=lambda _, __, ___: [], - user_context=user_context, + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context ) if session is not None: @@ -102,6 +107,7 @@ async def handle_sign_in_up_api( api_options=api_options, user_context=user_context, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if isinstance(result, SignInUpPostOkResult): diff --git a/supertokens_python/recipe/thirdparty/asyncio/__init__.py b/supertokens_python/recipe/thirdparty/asyncio/__init__.py index 1568362ae..e4735ac76 100644 --- a/supertokens_python/recipe/thirdparty/asyncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/asyncio/__init__.py @@ -48,6 +48,7 @@ async def manually_create_or_update_user( session=session, tenant_id=tenant_id, user_context=user_context, + should_try_linking_with_session_user=session is not None, ) diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index 32edc6828..1cd9fcb3a 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -91,6 +91,7 @@ async def manually_create_or_update_user( email: str, is_verified: bool, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[ @@ -111,6 +112,7 @@ async def sign_in_up( oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[SignInUpOkResult, SignInUpNotAllowed, LinkingToSessionUserFailedError]: @@ -219,6 +221,7 @@ async def sign_in_up_post( redirect_uri_info: Optional[RedirectUriInfo], oauth_tokens: Optional[Dict[str, Any]], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], diff --git a/supertokens_python/recipe/thirdparty/recipe_implementation.py b/supertokens_python/recipe/thirdparty/recipe_implementation.py index becfc077c..3d9aed2a0 100644 --- a/supertokens_python/recipe/thirdparty/recipe_implementation.py +++ b/supertokens_python/recipe/thirdparty/recipe_implementation.py @@ -60,6 +60,7 @@ async def sign_in_up( oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[SignInUpOkResult, SignInUpNotAllowed, LinkingToSessionUserFailedError]: @@ -70,6 +71,7 @@ async def sign_in_up( tenant_id=tenant_id, is_verified=is_verified, session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, ) @@ -99,6 +101,7 @@ async def manually_create_or_update_user( email: str, is_verified: bool, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[ @@ -174,6 +177,7 @@ async def manually_create_or_update_user( recipe_user_id=recipe_user_id, session=session, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) if link_result.status != "OK": diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index 2ea628a68..b199e55a6 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -430,3 +430,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): def normalise_email(email: str) -> str: return email.strip().lower() + + +def get_normalised_should_try_linking_with_session_user_flag( + req: BaseRequest, body: Dict[str, Any] +) -> Optional[bool]: + if has_greater_than_equal_to_fdi(req, "3.1"): + return body.get("shouldTryLinkingWithSessionUser", False) + return None diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 46ba367a2..27f988177 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -393,6 +393,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -406,13 +407,19 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -422,7 +429,12 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) original_implementation.email_exists_get = email_exists_get @@ -443,6 +455,7 @@ async def sign_in_up_post( redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -457,6 +470,7 @@ async def sign_in_up_post( redirect_uri_info, oauth_tokens, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -512,6 +526,7 @@ async def consume_code_post( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -527,6 +542,7 @@ async def consume_code_post( device_id, link_code, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -536,6 +552,7 @@ async def create_code_post( email: Union[str, None], phone_number: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -546,13 +563,20 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, session, tenant_id, api_options, user_context + email, + phone_number, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) async def resend_code_post( device_id: str, pre_auth_session_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -566,6 +590,7 @@ async def resend_code_post( device_id, pre_auth_session_id, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index bc363b69d..3ba30b4aa 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -448,6 +448,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -461,13 +462,19 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -477,7 +484,12 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) original_implementation.email_exists_get = email_exists_get @@ -498,6 +510,7 @@ async def sign_in_up_post( redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -512,6 +525,7 @@ async def sign_in_up_post( redirect_uri_info, oauth_tokens, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -567,6 +581,7 @@ async def consume_code_post( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -582,6 +597,7 @@ async def consume_code_post( device_id, link_code, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -591,6 +607,7 @@ async def create_code_post( email: Union[str, None], phone_number: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -601,13 +618,20 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, session, tenant_id, api_options, user_context + email, + phone_number, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) async def resend_code_post( device_id: str, pre_auth_session_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -621,6 +645,7 @@ async def resend_code_post( device_id, pre_auth_session_id, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index dd8da11ac..c44758c08 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -399,6 +399,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -412,13 +413,19 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -428,7 +435,12 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) original_implementation.email_exists_get = email_exists_get @@ -449,6 +461,7 @@ async def sign_in_up_post( redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -463,6 +476,7 @@ async def sign_in_up_post( redirect_uri_info, oauth_tokens, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -518,6 +532,7 @@ async def consume_code_post( device_id: Union[str, None], link_code: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -533,6 +548,7 @@ async def consume_code_post( device_id, link_code, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -542,6 +558,7 @@ async def create_code_post( email: Union[str, None], phone_number: Union[str, None], session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -552,13 +569,20 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, session, tenant_id, api_options, user_context + email, + phone_number, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) async def resend_code_post( device_id: str, pre_auth_session_id: str, session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -572,6 +596,7 @@ async def resend_code_post( device_id, pre_auth_session_id, session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, diff --git a/tests/test_user_context.py b/tests/test_user_context.py index e5089c72c..039843f0d 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -11,7 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pathlib import Path from fastapi import FastAPI @@ -75,12 +75,18 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ): user_context = {"preSignInPOST": True} response = await og_sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) if ( "preSignInPOST" in user_context @@ -105,13 +111,19 @@ async def sign_up_( password: str, tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ): if "manualCall" in user_context: global signUpContextWorks signUpContextWorks = True response = await og_sign_up( - email, password, tenant_id, session, user_context + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, ) return response @@ -120,12 +132,18 @@ async def sign_in( password: str, tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ): if "preSignInPOST" in user_context: user_context["preSignIn"] = True response = await og_sign_in( - email, password, tenant_id, session, user_context + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, ) if "preSignInPOST" in user_context and "preSignIn" in user_context: user_context["postSignIn"] = True @@ -223,6 +241,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ): @@ -232,7 +251,12 @@ async def sign_in_post( signin_api_context_works = True return await og_sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) param.sign_in_post = sign_in_post @@ -246,6 +270,7 @@ async def sign_in( password: str, tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ): req = user_context.get("_default", {}).get("request") @@ -253,7 +278,14 @@ async def sign_in( nonlocal signin_context_works signin_context_works = True - return await og_sign_in(email, password, tenant_id, session, user_context) + return await og_sign_in( + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, + ) param.sign_in = sign_in return param @@ -343,6 +375,7 @@ async def sign_in_post( form_fields: List[FormField], tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ): @@ -354,7 +387,12 @@ async def sign_in_post( signin_api_context_works = True return await og_sign_in_post( - form_fields, tenant_id, session, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) param.sign_in_post = sign_in_post @@ -368,6 +406,7 @@ async def sign_in( password: str, tenant_id: str, session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], user_context: Dict[str, Any], ): req = get_request_from_user_context(user_context) @@ -385,7 +424,14 @@ async def sign_in( user_context["_default"]["request"] = orginal_request - return await og_sign_in(email, password, tenant_id, session, user_context) + return await og_sign_in( + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, + ) param.sign_in = sign_in return param From 8d7b6303799ea9f8ecef6010e98028e4a55c0f3d Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 18 Sep 2024 15:20:43 +0530 Subject: [PATCH 045/126] small change --- supertokens_python/recipe/session/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 13f3d8dca..7c5f1efd5 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -567,7 +567,7 @@ def anti_csrf_function( ( overwrite_session_during_sign_in_up if overwrite_session_during_sign_in_up is not None - else False + else True ), ) From 406172c3cd3d4e5d9a40b17b386b13df4c21180f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 23 Sep 2024 12:48:03 +0530 Subject: [PATCH 046/126] fixes a bunch of cyclic imports --- supertokens_python/auth_utils.py | 2 +- .../recipe/accountlinking/__init__.py | 11 +++-- .../recipe/accountlinking/asyncio/__init__.py | 3 +- .../accountlinking/recipe_implementation.py | 6 +-- .../recipe/accountlinking/types.py | 2 +- .../recipe/accountlinking/utils.py | 24 ++++++----- .../multifactorauth/asyncio/__init__.py | 2 +- .../recipe/multifactorauth/interfaces.py | 14 ++----- .../multifactorauth/recipe_implementation.py | 6 +-- .../recipe/multifactorauth/types.py | 10 +++-- .../recipe/multifactorauth/utils.py | 40 ++++++++++++------- .../recipe/multitenancy/api/implementation.py | 2 +- .../recipe/thirdparty/api/implementation.py | 19 ++++----- .../recipe/thirdparty/api/signinup.py | 3 +- .../recipe/thirdparty/interfaces.py | 3 +- .../thirdparty/recipe_implementation.py | 11 +++-- supertokens_python/recipe/totp/interfaces.py | 16 ++++---- supertokens_python/recipe/totp/recipe.py | 2 +- .../recipe/totp/recipe_implementation.py | 6 +-- supertokens_python/types.py | 9 +++-- 20 files changed, 101 insertions(+), 90 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index b0df1a895..2707ff347 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -29,7 +29,7 @@ User, LoginMethod, ) -from supertokens_python.recipe.accountlinking.interfaces import ( +from supertokens_python.types import ( RecipeUserId, ) from supertokens_python.recipe.session.exceptions import UnauthorisedError diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py index 866437269..c8aaedec6 100644 --- a/supertokens_python/recipe/accountlinking/__init__.py +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -16,28 +16,27 @@ from typing import Callable, Union, Optional, Dict, Any, Awaitable from . import types +from ...types import User +from ..session.interfaces import SessionContainer -from . import utils from .recipe import AccountLinkingRecipe -InputOverrideConfig = utils.InputOverrideConfig -AccountLinkingUser = types.User +InputOverrideConfig = types.InputOverrideConfig RecipeLevelUser = types.RecipeLevelUser AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId -SessionContainer = types.SessionContainer ShouldAutomaticallyLink = types.ShouldAutomaticallyLink ShouldNotAutomaticallyLink = types.ShouldNotAutomaticallyLink def init( on_account_linked: Optional[ - Callable[[AccountLinkingUser, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] ] = None, should_do_automatic_account_linking: Optional[ Callable[ [ AccountInfoWithRecipeIdAndUserId, - Optional[AccountLinkingUser], + Optional[User], Optional[SessionContainer], str, Dict[str, Any], diff --git a/supertokens_python/recipe/accountlinking/asyncio/__init__.py b/supertokens_python/recipe/accountlinking/asyncio/__init__.py index af8b2054b..549950447 100644 --- a/supertokens_python/recipe/accountlinking/asyncio/__init__.py +++ b/supertokens_python/recipe/accountlinking/asyncio/__init__.py @@ -13,7 +13,8 @@ # under the License. from typing import Any, Dict, Optional -from ..types import AccountInfoWithRecipeId, User, RecipeUserId +from ..types import AccountInfoWithRecipeId +from supertokens_python.types import User, RecipeUserId from ..recipe import AccountLinkingRecipe from supertokens_python.recipe.session import SessionContainer from supertokens_python.asyncio import get_user diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 91c1db61f..562437a16 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -35,12 +35,10 @@ LinkAccountsAccountInfoAlreadyAssociatedError, LinkAccountsInputUserNotPrimaryError, UnlinkAccountOkResult, - User, - RecipeUserId, - AccountInfo, ) from supertokens_python.normalised_url_path import NormalisedURLPath -from .types import AccountLinkingConfig, RecipeLevelUser +from .types import AccountLinkingConfig, RecipeLevelUser, AccountInfo +from supertokens_python.types import User, RecipeUserId if TYPE_CHECKING: from supertokens_python.querier import Querier diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 748c70059..03b8819b2 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -16,9 +16,9 @@ from typing import Callable, Dict, Any, Union, Optional, List, TYPE_CHECKING, Awaitable from typing_extensions import Literal from supertokens_python.recipe.accountlinking.interfaces import ( - AccountInfo, RecipeInterface, ) +from supertokens_python.types import AccountInfo if TYPE_CHECKING: from supertokens_python.types import ( diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index 5a5efe47f..d06a6555f 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -14,17 +14,17 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Awaitable -from .types import ( - AccountLinkingConfig, - User, - RecipeLevelUser, - AccountInfoWithRecipeIdAndUserId, - SessionContainer, - ShouldNotAutomaticallyLink, - ShouldAutomaticallyLink, - InputOverrideConfig, - OverrideConfig, -) +if TYPE_CHECKING: + from .types import ( + AccountLinkingConfig, + User, + RecipeLevelUser, + AccountInfoWithRecipeIdAndUserId, + SessionContainer, + ShouldNotAutomaticallyLink, + ShouldAutomaticallyLink, + InputOverrideConfig, + ) if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo @@ -70,6 +70,8 @@ def validate_and_normalise_user_input( ] = None, override: Union[InputOverrideConfig, None] = None, ) -> AccountLinkingConfig: + from .types import OverrideConfig + global _did_use_default_should_do_automatic_account_linking if override is None: override = InputOverrideConfig() diff --git a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py index 9bdbc5e36..56817eb5f 100644 --- a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py +++ b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py @@ -18,7 +18,7 @@ from supertokens_python.recipe.session import SessionContainer -from ..interfaces import ( +from ..types import ( MFARequirementList, ) from ..recipe import MultiFactorAuthRecipe diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 9e5632a2b..9f237ca98 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -15,22 +15,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, Any, Union, List, Callable, Awaitable -from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe - -from supertokens_python.types import User - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Union - -from ...supertokens import AppInfo - +from typing import Dict, Any, Union, List, Callable, Awaitable, TYPE_CHECKING from ...types import APIResponse, GeneralErrorResponse if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer from .types import MFARequirementList, MultiFactorAuthConfig + from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe + from ...supertokens import AppInfo + from supertokens_python.types import User class RecipeInterface(ABC): diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py index b0909fc48..e2bc64251 100644 --- a/supertokens_python/recipe/multifactorauth/recipe_implementation.py +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -34,15 +34,13 @@ ) from supertokens_python.recipe.session import SessionContainer -from .interfaces import RecipeInterface - -from .recipe import MultiFactorAuthRecipe from supertokens_python.types import User from .utils import update_and_get_mfa_related_info_in_session - +from .interfaces import RecipeInterface if TYPE_CHECKING: from supertokens_python.querier import Querier + from .recipe import MultiFactorAuthRecipe class Validator(SessionClaimValidator): diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index d070d0518..d43477291 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -11,21 +11,23 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - -from typing import Awaitable, Dict, Any, Union, List, Optional, Callable +from __future__ import annotations +from typing import Awaitable, Dict, Any, Union, List, Optional, Callable, TYPE_CHECKING from supertokens_python.recipe.multitenancy.interfaces import TenantConfig -from .interfaces import RecipeInterface, APIInterface from typing_extensions import Literal from supertokens_python.types import User, RecipeUserId +if TYPE_CHECKING: + from .interfaces import RecipeInterface, APIInterface + class MFARequirementList(List[Union[Dict[str, List[str]], str]]): def __init__( self, *args: Union[ str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]] - ] + ], ): super().__init__() for arg in args: diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index 23b4d30d4..cd9dcbad4 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -15,9 +15,6 @@ from typing import TYPE_CHECKING, List, Optional, Union from typing import Dict, Any, Union, List -from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( - MultiFactorAuthClaim, -) from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import get_session_information @@ -25,9 +22,6 @@ from supertokens_python.recipe.multitenancy.asyncio import get_tenant from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.types import FactorIds -from supertokens_python.recipe.multifactorauth.recipe import ( - MultiFactorAuthRecipe as Recipe, -) from supertokens_python.recipe.multifactorauth.types import ( MFAClaimValue, MFARequirementList, @@ -80,6 +74,13 @@ async def update_and_get_mfa_related_info_in_session( input_session: Optional[SessionContainer] = None, input_updated_factor_id: Optional[str] = None, ) -> UpdateAndGetMFARelatedInfoInSessionResult: + from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, + ) + from supertokens_python.recipe.multifactorauth.recipe import ( + MultiFactorAuthRecipe as Recipe, + ) + session_recipe_user_id: RecipeUserId tenant_id: str access_token_payload: Dict[str, Any] @@ -207,19 +208,28 @@ async def get_required_secondary_factors_for_tenant( else [] ) + async def get_factors_setup_for_user() -> List[str]: + return await Recipe.get_instance_or_throw_error().recipe_implementation.get_factors_setup_for_user( + user=(await user_getter()), user_context=user_context + ) + + async def get_required_secondary_factors_for_user() -> List[str]: + return await Recipe.get_instance_or_throw_error().recipe_implementation.get_required_secondary_factors_for_user( + user_id=(await user_getter()).id, user_context=user_context + ) + + async def get_required_secondary_factors_for_tenant_helper() -> List[str]: + return await get_required_secondary_factors_for_tenant( + tenant_id=tenant_id, user_context=user_context + ) + mfa_requirements_for_auth = await Recipe.get_instance_or_throw_error().recipe_implementation.get_mfa_requirements_for_auth( tenant_id=tenant_id, access_token_payload=access_token_payload, user=user_getter, - factors_set_up_for_user=lambda: Recipe.get_instance_or_throw_error().recipe_implementation.get_factors_setup_for_user( - user=(await user_getter()), user_context=user_context - ), - required_secondary_factors_for_user=lambda: Recipe.get_instance_or_throw_error().recipe_implementation.get_required_secondary_factors_for_user( - user_id=(await user_getter()).id, user_context=user_context - ), - required_secondary_factors_for_tenant=lambda: get_required_secondary_factors_for_tenant( - tenant_id, user_context - ), + factors_set_up_for_user=get_factors_setup_for_user, + required_secondary_factors_for_user=get_required_secondary_factors_for_user, + required_secondary_factors_for_tenant=get_required_secondary_factors_for_tenant_helper, completed_factors=completed_factors, user_context=user_context, ) diff --git a/supertokens_python/recipe/multitenancy/api/implementation.py b/supertokens_python/recipe/multitenancy/api/implementation.py index 5806afb57..27a58624f 100644 --- a/supertokens_python/recipe/multitenancy/api/implementation.py +++ b/supertokens_python/recipe/multitenancy/api/implementation.py @@ -25,7 +25,6 @@ from supertokens_python.types import GeneralErrorResponse from ..interfaces import APIInterface, ThirdPartyProvider -from ...multifactorauth.utils import is_valid_first_factor class APIImplementation(APIInterface): @@ -36,6 +35,7 @@ async def login_methods_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[LoginMethodsGetOkResult, GeneralErrorResponse]: + from ...multifactorauth.utils import is_valid_first_factor from supertokens_python.recipe.thirdparty.providers.config_utils import ( merge_providers_from_core_and_static, find_and_create_provider_instance, diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index 72809872d..f5266191f 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -17,15 +17,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Union from urllib.parse import parse_qs, urlencode, urlparse -from supertokens_python.auth_utils import ( - OkResponse, - PostAuthChecksOkResponse, - SignInNotAllowedResponse, - SignUpNotAllowedResponse, - get_authenticating_user_and_add_to_current_tenant_if_required, - post_auth_checks, - pre_auth_checks, -) from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId from supertokens_python.recipe.emailverification import EmailVerificationRecipe @@ -83,6 +74,16 @@ async def sign_in_up_post( SignInUpNotAllowed, GeneralErrorResponse, ]: + from supertokens_python.auth_utils import ( + OkResponse, + PostAuthChecksOkResponse, + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + get_authenticating_user_and_add_to_current_tenant_if_required, + post_auth_checks, + pre_auth_checks, + ) + error_code_map = { "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_006)", "SIGN_IN_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_004)", diff --git a/supertokens_python/recipe/thirdparty/api/signinup.py b/supertokens_python/recipe/thirdparty/api/signinup.py index 7ae2a2d2f..2d62de3cd 100644 --- a/supertokens_python/recipe/thirdparty/api/signinup.py +++ b/supertokens_python/recipe/thirdparty/api/signinup.py @@ -14,7 +14,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict -from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.recipe.thirdparty.interfaces import SignInUpPostOkResult from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo @@ -35,6 +34,8 @@ async def handle_sign_in_up_api( api_options: APIOptions, user_context: Dict[str, Any], ): + from supertokens_python.auth_utils import load_session_in_auth_api_if_needed + if api_implementation.disable_sign_in_up_post: return None diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index 1cd9fcb3a..9a943c3aa 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -16,8 +16,6 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from supertokens_python.auth_utils import LinkingToSessionUserFailedError - from ...types import APIResponse, User, GeneralErrorResponse, RecipeUserId from .provider import Provider, ProviderInput, RedirectUriInfo @@ -25,6 +23,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer from supertokens_python.supertokens import AppInfo + from supertokens_python.auth_utils import LinkingToSessionUserFailedError from .types import RawUserInfoFromProvider from .utils import ThirdPartyConfig diff --git a/supertokens_python/recipe/thirdparty/recipe_implementation.py b/supertokens_python/recipe/thirdparty/recipe_implementation.py index 3d9aed2a0..3a0a2824d 100644 --- a/supertokens_python/recipe/thirdparty/recipe_implementation.py +++ b/supertokens_python/recipe/thirdparty/recipe_implementation.py @@ -15,10 +15,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from supertokens_python.asyncio import get_user, list_users_by_account_info -from supertokens_python.auth_utils import ( - LinkingToSessionUserFailedError, - link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, -) from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe @@ -34,6 +30,9 @@ if TYPE_CHECKING: from supertokens_python.querier import Querier + from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedError, + ) from .interfaces import ( EmailChangeNotAllowedError, @@ -110,6 +109,10 @@ async def manually_create_or_update_user( SignInUpNotAllowed, EmailChangeNotAllowedError, ]: + from supertokens_python.auth_utils import ( + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, + ) + account_linking = AccountLinkingRecipe.get_instance() users = await list_users_by_account_info( tenant_id, diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index 6b35b9b47..3cb4d484e 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -12,15 +12,17 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Dict, Any, Union +from __future__ import annotations +from typing import Dict, Any, Union, TYPE_CHECKING from abc import ABC, abstractmethod -from supertokens_python import AppInfo -from supertokens_python.framework import BaseRequest, BaseResponse -from supertokens_python.recipe.session import SessionContainer -from supertokens_python.recipe.totp.recipe import TOTPRecipe -from supertokens_python.types import GeneralErrorResponse -from .types import * +if TYPE_CHECKING: + from .types import * + from supertokens_python.recipe.session import SessionContainer + from supertokens_python import AppInfo + from supertokens_python.framework import BaseRequest, BaseResponse + from supertokens_python.recipe.totp.recipe import TOTPRecipe + from supertokens_python.types import GeneralErrorResponse class RecipeInterface(ABC): diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py index 16be078a0..bcb5ba23c 100644 --- a/supertokens_python/recipe/totp/recipe.py +++ b/supertokens_python/recipe/totp/recipe.py @@ -40,7 +40,7 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from api.list_devices import handle_list_devices_api +from .api.list_devices import handle_list_devices_api from .api.create_device import handle_create_device_api from .api.remove_device import handle_remove_device_api from .api.verify_device import handle_verify_device_api diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py index 5b4ec544f..b213c4eb2 100644 --- a/supertokens_python/recipe/totp/recipe_implementation.py +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -19,7 +19,8 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.recipe.totp.interfaces import ( RecipeInterface, - CreateDeviceOkResult, +) +from .types import ( UnknownUserIdError, UpdateDeviceOkResult, ListDevicesOkResult, @@ -28,8 +29,7 @@ VerifyTOTPOkResult, UserIdentifierInfoOkResult, UserIdentifierInfoDoesNotExistError, -) -from .types import ( + CreateDeviceOkResult, Device, DeviceAlreadyExistsError, InvalidTOTPError, diff --git a/supertokens_python/types.py b/supertokens_python/types.py index eba2d19a3..7982acbe0 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -11,17 +11,18 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Awaitable, Dict, List, TypeVar, Union, Optional +from typing import Any, Awaitable, Dict, List, TypeVar, Union, Optional, TYPE_CHECKING from phonenumbers import format_number, parse # type: ignore import phonenumbers # type: ignore from typing_extensions import Literal -from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo - _T = TypeVar("_T") +if TYPE_CHECKING: + from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo + class RecipeUserId: def __init__(self, recipe_user_id: str): From d6ee839cc8a7267580d834dc7367b1bdf0c668f6 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 23 Sep 2024 14:16:57 +0530 Subject: [PATCH 047/126] fixes --- supertokens_python/normalised_url_path.py | 5 +++++ supertokens_python/recipe/accountlinking/recipe.py | 2 +- supertokens_python/recipe/accountlinking/utils.py | 10 +++++++--- supertokens_python/recipe/multitenancy/recipe.py | 7 ++++++- .../claim_base_classes/primitive_array_claim.py | 4 ++-- supertokens_python/recipe/userroles/recipe.py | 13 +++++++++++-- tests/test_config.py | 10 ++++++---- tests/utils.py | 6 ++++++ 8 files changed, 44 insertions(+), 13 deletions(-) diff --git a/supertokens_python/normalised_url_path.py b/supertokens_python/normalised_url_path.py index 11cf9c5c5..31075a303 100644 --- a/supertokens_python/normalised_url_path.py +++ b/supertokens_python/normalised_url_path.py @@ -54,6 +54,7 @@ def normalise_url_path_or_throw_error(input_str: str) -> str: return input_str except Exception: pass + if ( (domain_given(input_str) or input_str.startswith("localhost")) and not input_str.startswith("http://") @@ -61,8 +62,10 @@ def normalise_url_path_or_throw_error(input_str: str) -> str: ): input_str = "http://" + input_str return normalise_url_path_or_throw_error(input_str) + if not input_str.startswith("/"): input_str = "/" + input_str + try: urlparse("http://example.com" + input_str) return normalise_url_path_or_throw_error("http://example.com" + input_str) @@ -74,6 +77,8 @@ def domain_given(input_str: str) -> bool: if "." not in input_str or input_str.startswith("/"): return False try: + if not "http://" in input_str and not "https://" in input_str: + raise Exception("Trying with http") url = urlparse(input_str) return url.hostname is not None and "." in url.hostname except Exception: diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py index 8c9154665..2f98825be 100644 --- a/supertokens_python/recipe/accountlinking/recipe.py +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -102,7 +102,7 @@ def __init__( Querier.get_instance(recipe_id), self, self.config ) - self.recipe_implementation = ( + self.recipe_implementation: RecipeInterface = ( recipe_implementation if self.config.override.functions is None else self.config.override.functions(recipe_implementation) diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index d06a6555f..fd5c38cf1 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -70,17 +70,21 @@ def validate_and_normalise_user_input( ] = None, override: Union[InputOverrideConfig, None] = None, ) -> AccountLinkingConfig: - from .types import OverrideConfig + from .types import ( + OverrideConfig, + InputOverrideConfig as IOC, + AccountLinkingConfig as ALC, + ) global _did_use_default_should_do_automatic_account_linking if override is None: - override = InputOverrideConfig() + override = IOC() _did_use_default_should_do_automatic_account_linking = ( should_do_automatic_account_linking is None ) - return AccountLinkingConfig( + return ALC( override=OverrideConfig(functions=override.functions), on_account_linked=( default_on_account_linked diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 1c46bcf78..8b0fee9e4 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -21,6 +21,7 @@ PrimitiveArrayClaim, ) from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.types import RecipeUserId from ...post_init_callbacks import PostSTInitCallbacks @@ -209,7 +210,11 @@ def __init__(self): default_max_age_in_sec = 60 * 60 async def fetch_value( - _: str, tenant_id: str, user_context: Dict[str, Any] + _user_id: str, + _recipe_user_id: RecipeUserId, + tenant_id: str, + _current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> Optional[List[str]]: recipe = MultitenancyRecipe.get_instance() diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py index f34afc750..5c13a02b6 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar, Union, Generic, List -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import get_timestamp_ms from ..interfaces import ( @@ -267,7 +267,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[PrimitiveList]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index a4f4189b4..844b41895 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -27,6 +27,7 @@ from supertokens_python.recipe.userroles.utils import validate_and_normalise_user_input from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo +from supertokens_python.types import RecipeUserId from ...post_init_callbacks import PostSTInitCallbacks from ..session import SessionRecipe @@ -151,7 +152,11 @@ def __init__(self) -> None: default_max_age_in_sec = 300 async def fetch_value( - user_id: str, tenant_id: str, user_context: Dict[str, Any] + user_id: str, + _recipe_user_id: RecipeUserId, + tenant_id: str, + _current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> List[str]: recipe = UserRolesRecipe.get_instance() @@ -186,7 +191,11 @@ def __init__(self) -> None: default_max_age_in_sec = 300 async def fetch_value( - user_id: str, tenant_id: str, user_context: Dict[str, Any] + user_id: str, + _recipe_user_id: RecipeUserId, + tenant_id: str, + _current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> List[str]: recipe = UserRolesRecipe.get_instance() res = await recipe.recipe_implementation.get_roles_for_user( diff --git a/tests/test_config.py b/tests/test_config.py index d8a64761b..de831e405 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -718,6 +718,10 @@ async def test_samesite_valid_config(): @mark.asyncio async def test_samesite_invalid_config(): + reset() + clean_st() + setup_st() + start_st() domain_combinations = [ ["http://localhost:3000", "http://supertokensapi.io"], ["http://127.0.0.1:3000", "http://supertokensapi.io"], @@ -726,9 +730,7 @@ async def test_samesite_invalid_config(): ["http://supertokens.io", "http://supertokensapi.io"], ] for website_domain, api_domain in domain_combinations: - reset() - clean_st() - setup_st() + reset(False) try: init( supertokens_config=SupertokensConfig("http://localhost:3567"), @@ -746,7 +748,7 @@ async def test_samesite_invalid_config(): ], ) await create_new_session( - "public", MagicMock(), RecipeUserId("userId"), {}, {} + MagicMock(), "public", RecipeUserId("userId"), {}, {} ) except Exception as e: assert ( diff --git a/tests/utils.py b/tests/utils.py index bc4386cfa..7d664ab39 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -41,6 +41,9 @@ from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.userroles import UserRolesRecipe from supertokens_python.utils import is_version_gte +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.totp.recipe import TOTPRecipe INSTALLATION_PATH = environ["SUPERTOKENS_PATH"] SUPERTOKENS_PROCESS_DIR = INSTALLATION_PATH + "/.started" @@ -219,6 +222,9 @@ def reset(stop_core: bool = True): DashboardRecipe.reset() PasswordlessRecipe.reset() MultitenancyRecipe.reset() + AccountLinkingRecipe.reset() + MultiFactorAuthRecipe.reset() + TOTPRecipe.reset() def get_cookie_from_response(response: Response, cookie_name: str): From 4e0e11add673a8d423a0302ca162ac674589c35f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 23 Sep 2024 14:26:42 +0530 Subject: [PATCH 048/126] fixes --- tests/test_config_without_core.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/test_config_without_core.py b/tests/test_config_without_core.py index 5cbe7e28c..4d3b2a7b4 100644 --- a/tests/test_config_without_core.py +++ b/tests/test_config_without_core.py @@ -1,9 +1,9 @@ from pytest import mark -from supertokens_python import InputAppInfo, Supertokens, SupertokensConfig, init +from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.recipe import session from supertokens_python.recipe.session import SessionRecipe -from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from tests.utils import reset @mark.parametrize( @@ -31,9 +31,7 @@ def test_same_site_cookie_values( api_domain: str, website_domain: str, cookie_same_site: str ): - Supertokens.reset() - SessionRecipe.reset() - MultitenancyRecipe.reset() + reset() init( supertokens_config=SupertokensConfig("http://localhost:3567"), @@ -53,6 +51,3 @@ def test_same_site_cookie_values( s = SessionRecipe.get_instance() assert s.config.get_cookie_same_site(None, {}) == cookie_same_site - SessionRecipe.reset() - MultitenancyRecipe.reset() - Supertokens.reset() From 537db954a8a83ff2abf1c21ddeef77953452b194 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 24 Sep 2024 12:24:24 +0530 Subject: [PATCH 049/126] more fixes --- supertokens_python/auth_utils.py | 4 ++-- .../recipe/accountlinking/types.py | 8 +++++--- .../recipe/accountlinking/utils.py | 6 +++++- .../recipe/emailpassword/api/signin.py | 17 ++++++++++------- .../recipe/emailpassword/api/signup.py | 17 ++++++++++------- .../recipe/passwordless/api/consume_code.py | 17 ++++++++++------- .../recipe/thirdparty/api/signinup.py | 17 ++++++++++------- supertokens_python/types.py | 14 +++++++++----- tests/test_middleware.py | 6 ++++-- 9 files changed, 65 insertions(+), 41 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 2707ff347..da67d7b87 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -447,8 +447,8 @@ async def get_authenticating_user_and_add_to_current_tenant_if_required( class OkFirstFactorResponse: - status: Literal["OK"] - is_first_factor: Literal[True] + status: Literal["OK"] = "OK" + is_first_factor: Literal[True] = True class OkSecondFactorLinkedResponse: diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 03b8819b2..922b559fb 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -91,15 +91,17 @@ def __init__( def from_account_info_or_login_method( account_info: Union[AccountInfoWithRecipeId, LoginMethod], ) -> AccountInfoWithRecipeIdAndUserId: + from supertokens_python.types import ( + LoginMethod as LM, + ) + return AccountInfoWithRecipeIdAndUserId( recipe_id=account_info.recipe_id, email=account_info.email, phone_number=account_info.phone_number, third_party=account_info.third_party, recipe_user_id=( - account_info.recipe_user_id - if isinstance(account_info, LoginMethod) - else None + account_info.recipe_user_id if isinstance(account_info, LM) else None ), ) diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index fd5c38cf1..5f0209fd1 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -44,7 +44,11 @@ async def default_should_do_automatic_account_linking( _____: str, ______: Dict[str, Any], ) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: - return ShouldNotAutomaticallyLink() + from .types import ( + ShouldNotAutomaticallyLink as SNAL, + ) + + return SNAL() def recipe_init_defined_should_do_automatic_account_linking() -> bool: diff --git a/supertokens_python/recipe/emailpassword/api/signin.py b/supertokens_python/recipe/emailpassword/api/signin.py index af1e893ce..7b99d5f36 100644 --- a/supertokens_python/recipe/emailpassword/api/signin.py +++ b/supertokens_python/recipe/emailpassword/api/signin.py @@ -72,13 +72,16 @@ async def handle_sign_in_api( if isinstance(response, SignInPostOkResult): return send_200_response( - get_backwards_compatible_user_info( - req=api_options.request, - user_info=response.user, - session_container=response.session, - created_new_recipe_user=None, - user_context=user_context, - ), + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=response.user, + session_container=response.session, + created_new_recipe_user=None, + user_context=user_context, + ), + }, api_options.response, ) diff --git a/supertokens_python/recipe/emailpassword/api/signup.py b/supertokens_python/recipe/emailpassword/api/signup.py index 529cbb866..13c9aacd1 100644 --- a/supertokens_python/recipe/emailpassword/api/signup.py +++ b/supertokens_python/recipe/emailpassword/api/signup.py @@ -78,13 +78,16 @@ async def handle_sign_up_api( if isinstance(response, SignUpPostOkResult): return send_200_response( - get_backwards_compatible_user_info( - req=api_options.request, - user_info=response.user, - session_container=response.session, - created_new_recipe_user=None, - user_context=user_context, - ), + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=response.user, + session_container=response.session, + created_new_recipe_user=None, + user_context=user_context, + ), + }, api_options.response, ) if isinstance(response, GeneralErrorResponse): diff --git a/supertokens_python/recipe/passwordless/api/consume_code.py b/supertokens_python/recipe/passwordless/api/consume_code.py index e385e4ea1..333194200 100644 --- a/supertokens_python/recipe/passwordless/api/consume_code.py +++ b/supertokens_python/recipe/passwordless/api/consume_code.py @@ -92,13 +92,16 @@ async def consume_code( if isinstance(result, ConsumeCodePostOkResult): return send_200_response( - get_backwards_compatible_user_info( - req=api_options.request, - user_info=result.user, - session_container=result.session, - created_new_recipe_user=result.created_new_recipe_user, - user_context=user_context, - ), + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=result.user, + session_container=result.session, + created_new_recipe_user=result.created_new_recipe_user, + user_context=user_context, + ), + }, api_options.response, ) diff --git a/supertokens_python/recipe/thirdparty/api/signinup.py b/supertokens_python/recipe/thirdparty/api/signinup.py index 2d62de3cd..04b8daa19 100644 --- a/supertokens_python/recipe/thirdparty/api/signinup.py +++ b/supertokens_python/recipe/thirdparty/api/signinup.py @@ -113,13 +113,16 @@ async def handle_sign_in_up_api( if isinstance(result, SignInUpPostOkResult): return send_200_response( - get_backwards_compatible_user_info( - req=api_options.request, - user_info=result.user, - session_container=result.session, - created_new_recipe_user=result.created_new_recipe_user, - user_context=user_context, - ), + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=result.user, + session_container=result.session, + created_new_recipe_user=result.created_new_recipe_user, + user_context=user_context, + ), + }, api_options.response, ) diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 7982acbe0..a42a9cfa4 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -120,12 +120,16 @@ def from_json(json: Dict[str, Any]) -> "LoginMethod": recipe_id=json["recipeId"], recipe_user_id=json["recipeUserId"], tenant_ids=json["tenantIds"], - email=json["email"], - phone_number=json["phoneNumber"], + email=json["email"] if "email" in json else None, + phone_number=json["phoneNumber"] if "phoneNumber" in json else None, third_party=( - ThirdPartyInfo(json["thirdParty"]["id"], json["thirdParty"]["userId"]) - if json["thirdParty"] - else None + ( + ThirdPartyInfo( + json["thirdParty"]["id"], json["thirdParty"]["userId"] + ) + if "thirdParty" in json + else None + ) ), time_joined=json["timeJoined"], verified=json["verified"], diff --git a/tests/test_middleware.py b/tests/test_middleware.py index c5c5b8003..b7911ee4c 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -190,7 +190,8 @@ async def test_wrong_rid_with_existing_api_works( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"][0] == user_info["emails"][0] + assert len(dict_response["user"]["emails"]) == 1 @mark.asyncio @@ -239,7 +240,8 @@ async def test_random_rid_with_existing_api_works( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"][0] == user_info["emails"][0] + assert len(dict_response["user"]["emails"]) == 1 @mark.asyncio From 960af0e0abe1b4d22834839691de8adf96429da5 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 24 Sep 2024 12:41:39 +0530 Subject: [PATCH 050/126] fixes --- supertokens_python/asyncio/__init__.py | 4 ++-- .../recipe/accountlinking/recipe_implementation.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index 7fa32322e..17b83c1d1 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -41,7 +41,7 @@ async def get_users_oldest_first( user_context = {} return await AccountLinkingRecipe.get_instance().recipe_implementation.get_users( tenant_id, - time_joined_order="DESC", + time_joined_order="ASC", limit=limit, pagination_token=pagination_token, include_recipe_ids=include_recipe_ids, @@ -62,7 +62,7 @@ async def get_users_newest_first( user_context = {} return await AccountLinkingRecipe.get_instance().recipe_implementation.get_users( tenant_id, - time_joined_order="ASC", + time_joined_order="DESC", limit=limit, pagination_token=pagination_token, include_recipe_ids=include_recipe_ids, diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 562437a16..4d8dc56f2 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -71,12 +71,15 @@ async def get_users( if include_recipe_ids is not None: include_recipe_ids_str = ",".join(include_recipe_ids) - params = { - "includeRecipeIds": include_recipe_ids_str, + params: Dict[str, Any] = { "timeJoinedOrder": time_joined_order, - "limit": limit, - "paginationToken": pagination_token, } + if limit is not None: + params["limit"] = limit + if pagination_token is not None: + params["paginationToken"] = pagination_token + if include_recipe_ids_str is not None: + params["includeRecipeIds"] = include_recipe_ids_str if query: params.update(query) From a2873c49220fd9fa57460238a46adc4c2cb3ae65 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 24 Sep 2024 13:24:40 +0530 Subject: [PATCH 051/126] fixes --- .../accountlinking/recipe_implementation.py | 8 ++-- .../recipe/passwordless/api/implementation.py | 3 ++ .../recipe/passwordless/interfaces.py | 4 +- .../passwordless/recipe_implementation.py | 2 +- tests/test_passwordless.py | 40 ++++++++++++++----- 5 files changed, 41 insertions(+), 16 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 4d8dc56f2..c9125b2fb 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -330,11 +330,13 @@ async def list_users_by_account_info( do_union_of_account_info: bool, user_context: Dict[str, Any], ) -> List[User]: - params = { - "email": account_info.email, - "phoneNumber": account_info.phone_number, + params: Dict[str, Any] = { "doUnionOfAccountInfo": do_union_of_account_info, } + if account_info.email is not None: + params["email"] = account_info.email + if account_info.phone_number is not None: + params["phoneNumber"] = account_info.phone_number if account_info.third_party: params["thirdPartyId"] = account_info.third_party.id diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index c1eb9fd1e..17aaaeba7 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -130,6 +130,9 @@ async def get_passwordless_user_by_account_info( "This should never happen: multiple users exist matching the accountInfo in passwordless createCode" ) + if len(users_with_matching_login_methods) == 0: + return None + return users_with_matching_login_methods[0] diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 162dd0b3e..4a92ae6ee 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -107,8 +107,8 @@ def from_json(json: Dict[str, Any]) -> ConsumedDevice: return ConsumedDevice( pre_auth_session_id=json["preAuthSessionId"], failed_code_input_attempt_count=json["failedCodeInputAttemptCount"], - email=json["email"], - phone_number=json["phoneNumber"], + email=json["email"] if "email" in json else None, + phone_number=json["phoneNumber"] if "phoneNumber" in json else None, ) diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 17b1bc42e..371c32a14 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -217,7 +217,7 @@ async def create_code( device_id=response["deviceId"], user_input_code=response["userInputCode"], link_code=response["linkCode"], - code_life_time=response["codeLifeTime"], + code_life_time=response["codeLifetime"], time_created=response["timeCreated"], ) diff --git a/tests/test_passwordless.py b/tests/test_passwordless.py index bc376e9f9..db271ab79 100644 --- a/tests/test_passwordless.py +++ b/tests/test_passwordless.py @@ -115,12 +115,30 @@ async def send_sms( ).json() consume_code_json["user"].pop("id") - consume_code_json["user"].pop("time_joined") + consume_code_json["user"].pop("timeJoined") + consume_code_json["user"]["loginMethods"][0].pop("recipeUserId") + consume_code_json["user"]["loginMethods"][0].pop("timeJoined") assert consume_code_json == { "status": "OK", - "createdNewUser": True, - "user": {"phoneNumber": "+919494949494"}, + "createdNewRecipeUser": True, + "user": { + "isPrimaryUser": False, + "tenantIds": ["public"], + "emails": [], + "phoneNumbers": ["+919494949494"], + "thirdParty": [], + "loginMethods": [ + { + "recipeId": "passwordless", + "tenantIds": ["public"], + "email": None, + "phoneNumber": "+919494949494", + "thirdParty": None, + "verified": True, + } + ], + }, } @@ -225,7 +243,7 @@ async def send_sms( user_id = consume_code_json["user"]["id"] - await update_user(user_id, "foo@example.com", "+919494949494") + await update_user(RecipeUserId(user_id), "foo@example.com", "+919494949494") response = await delete_phone_number_for_user(RecipeUserId(user_id)) assert isinstance(response, UpdateUserOkResult) @@ -233,7 +251,7 @@ async def send_sms( user = await list_users_by_account_info( "public", AccountInfo(phone_number="+919494949494") ) - assert user is None + assert len(user) == 0 user = await get_user(user_id) assert user is not None and user.phone_numbers == [] @@ -310,7 +328,7 @@ async def send_sms( user_id = consume_code_json["user"]["id"] - await update_user(user_id, "hello@example.com", "+919494949494") + await update_user(RecipeUserId(user_id), "hello@example.com", "+919494949494") response = await delete_email_for_user(RecipeUserId(user_id)) assert isinstance(response, UpdateUserOkResult) @@ -318,7 +336,7 @@ async def send_sms( user = await list_users_by_account_info( "public", AccountInfo(email="hello@example.com") ) - assert user is None + assert len(user) == 0 user = await get_user(user_id) assert user is not None and user.emails == [] @@ -397,14 +415,16 @@ async def send_sms( user_id = consume_code_json["user"]["id"] - response = await update_user(user_id, "hello@example.com", "+919494949494") + response = await update_user( + RecipeUserId(user_id), "hello@example.com", "+919494949494" + ) assert isinstance(response, UpdateUserOkResult) # Delete the email - response = await delete_email_for_user(user_id) + response = await delete_email_for_user(RecipeUserId(user_id)) # Delete the phone number (Should raise exception because deleting both of them isn't allowed) with raises(Exception) as e: - response = await delete_phone_number_for_user(user_id) + response = await delete_phone_number_for_user(RecipeUserId(user_id)) assert e.value.args[0].endswith( "You cannot clear both email and phone number of a user\n" From 36eaa1fc601c17b47c601cdef2b2f911bed44ead Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 24 Sep 2024 16:50:52 +0530 Subject: [PATCH 052/126] fixes --- supertokens_python/querier.py | 7 +++---- tests/test_querier.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 23fb006ea..c330a56ea 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -74,7 +74,6 @@ class Querier: def __init__(self, hosts: List[Host], rid_to_core: Union[None, str] = None): self.__hosts = hosts self.__rid_to_core = None - self.__global_cache_tag = get_timestamp_ms() if rid_to_core is not None: self.__rid_to_core = rid_to_core @@ -277,7 +276,7 @@ async def f(url: str, method: str) -> Response: if user_context is not None: if ( user_context.get("_default", {}).get("global_cache_tag", -1) - != self.__global_cache_tag + != Querier.__global_cache_tag ): self.invalidate_core_call_cache(user_context, False) @@ -316,7 +315,7 @@ async def f(url: str, method: str) -> Response: **user_context.get("_default", {}).get("core_call_cache", {}), unique_key: response, }, - "global_cache_tag": self.__global_cache_tag, + "global_cache_tag": Querier.__global_cache_tag, } return response @@ -443,7 +442,7 @@ def invalidate_core_call_cache( user_context.get("_default", {}).get("keep_cache_alive", False) is not True ): # there can be race conditions here, but i think we can ignore them. - self.__global_cache_tag = get_timestamp_ms() + Querier.__global_cache_tag = get_timestamp_ms() user_context["_default"] = { **user_context.get("_default", {}), diff --git a/tests/test_querier.py b/tests/test_querier.py index 7ef9323c8..0af1d19e7 100644 --- a/tests/test_querier.py +++ b/tests/test_querier.py @@ -244,14 +244,14 @@ def intercept( assert user is None assert not called_core - user = await get_user("random", user_context) + user = await get_user("random2", user_context) assert user is None assert called_core called_core = False - user = await get_user("random", user_context) + user = await get_user("random2", user_context) assert user is None assert not called_core @@ -417,7 +417,7 @@ def intercept( called_core = False - user = await get_user("random", user_context) + user = await get_user("random2", user_context) assert user is None assert called_core From 76cdb5424c27deb99878aa4fa2a518fef4c299e6 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 12:25:03 +0530 Subject: [PATCH 053/126] fixes more tests --- supertokens_python/types.py | 6 +++--- tests/Django/test_django.py | 7 +++++++ tests/dashboard/test_dashboard.py | 6 +++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/supertokens_python/types.py b/supertokens_python/types.py index a42a9cfa4..2eb5535cc 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -116,6 +116,8 @@ def to_json(self) -> Dict[str, Any]: @staticmethod def from_json(json: Dict[str, Any]) -> "LoginMethod": + from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo as TPI + return LoginMethod( recipe_id=json["recipeId"], recipe_user_id=json["recipeUserId"], @@ -124,9 +126,7 @@ def from_json(json: Dict[str, Any]) -> "LoginMethod": phone_number=json["phoneNumber"] if "phoneNumber" in json else None, third_party=( ( - ThirdPartyInfo( - json["thirdParty"]["id"], json["thirdParty"]["userId"] - ) + TPI(json["thirdParty"]["id"], json["thirdParty"]["userId"]) if "thirdParty" in json else None ) diff --git a/tests/Django/test_django.py b/tests/Django/test_django.py index 9e221abc6..66cd27357 100644 --- a/tests/Django/test_django.py +++ b/tests/Django/test_django.py @@ -502,6 +502,7 @@ async def test_thirdparty_parsing_works(self): f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}", ) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_multiple_emails(self): init( @@ -554,6 +555,7 @@ async def test_search_with_multiple_emails(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 1) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_email_t(self): init( @@ -604,6 +606,7 @@ async def test_search_with_email_t(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 5) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_email_iresh(self): init( @@ -656,6 +659,7 @@ async def test_search_with_email_iresh(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 0) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_phone_plus_one(self): init( @@ -711,6 +715,7 @@ async def test_search_with_phone_plus_one(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 3) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_phone_one_bracket(self): init( @@ -766,6 +771,7 @@ async def test_search_with_phone_one_bracket(self): self.assertEqual(response.status_code, 200) self.assertEqual(len(data_json["users"]), 0) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_provider_google(self): init( @@ -860,6 +866,7 @@ async def test_search_with_provider_google(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 3) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_provider_google_and_phone_one(self): init( diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index f4392d449..ab8e43763 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -136,9 +136,9 @@ async def should_allow_access( res = app.get(url="/auth/dashboard/api/users?limit=5") body = res.json() assert res.status_code == 200 - assert body["users"][0]["user"]["firstName"] == "User2" - assert body["users"][1]["user"]["lastName"] == "Foo" - assert body["users"][1]["user"]["firstName"] == "User1" + assert body["users"][0]["firstName"] == "User2" + assert body["users"][1]["lastName"] == "Foo" + assert body["users"][1]["firstName"] == "User1" async def test_connection_uri_has_http_prefix_if_localhost(app: TestClient): From 160af9b7172e055cd177285da67c36913c8779f3 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 13:41:10 +0530 Subject: [PATCH 054/126] fixes more tests --- .../recipe_implementation.py | 2 +- supertokens_python/types.py | 34 +++++++ tests/emailpassword/test_emaildelivery.py | 89 +++++++++++++------ tests/emailpassword/test_emailverify.py | 24 +++-- 4 files changed, 111 insertions(+), 38 deletions(-) diff --git a/supertokens_python/recipe/emailverification/recipe_implementation.py b/supertokens_python/recipe/emailverification/recipe_implementation.py index fce767d10..ea5ab3fb3 100644 --- a/supertokens_python/recipe/emailverification/recipe_implementation.py +++ b/supertokens_python/recipe/emailverification/recipe_implementation.py @@ -31,10 +31,10 @@ ) from .types import EmailVerificationUser from supertokens_python.asyncio import get_user +from supertokens_python.types import RecipeUserId, User if TYPE_CHECKING: from supertokens_python.querier import Querier - from supertokens_python.types import RecipeUserId, User class RecipeImplementation(RecipeInterface): diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 2eb5535cc..e4b5ee639 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -70,7 +70,23 @@ def __init__( self.time_joined = time_joined self.verified = verified + def __eq__(self, other: Any) -> bool: + if isinstance(other, LoginMethod): + return ( + self.recipe_id == other.recipe_id + and self.recipe_user_id == other.recipe_user_id + and self.tenant_ids == other.tenant_ids + and self.has_same_email_as(other.email) + and self.has_same_phone_number_as(other.phone_number) + and self.has_same_third_party_info_as(other.third_party) + and self.time_joined == other.time_joined + and self.verified == other.verified + ) + return False + def has_same_email_as(self, email: Union[str, None]) -> bool: + if self.email is None and email is None: + return True if email is None: return False return ( @@ -79,6 +95,8 @@ def has_same_email_as(self, email: Union[str, None]) -> bool: ) def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: + if self.phone_number is None and phone_number is None: + return True if phone_number is None: return False @@ -95,6 +113,8 @@ def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: def has_same_third_party_info_as( self, third_party: Union[ThirdPartyInfo, None] ) -> bool: + if third_party is None and self.third_party is None: + return True if third_party is None or self.third_party is None: return False return ( @@ -157,6 +177,20 @@ def __init__( self.login_methods = login_methods self.time_joined = time_joined + def __eq__(self, other: Any) -> bool: + if isinstance(other, User): + return ( + self.id == other.id + and self.is_primary_user == other.is_primary_user + and self.tenant_ids == other.tenant_ids + and self.emails == other.emails + and self.phone_numbers == other.phone_numbers + and self.third_party == other.third_party + and self.login_methods == other.login_methods + and self.time_joined == other.time_joined + ) + return False + def to_json(self) -> Dict[str, Any]: return { "id": self.id, diff --git a/tests/emailpassword/test_emaildelivery.py b/tests/emailpassword/test_emaildelivery.py index d24e88e5f..13c21695e 100644 --- a/tests/emailpassword/test_emaildelivery.py +++ b/tests/emailpassword/test_emaildelivery.py @@ -20,6 +20,7 @@ import respx from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -206,7 +207,9 @@ class CustomEmailService( async def send_email( self, template_vars: emailpassword.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ) -> None: nonlocal email, password_reset_url email = template_vars.user.email @@ -250,13 +253,14 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]): oi_send_email = oi.send_email async def send_email( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal email, password_reset_url email = template_vars.user.email assert isinstance(template_vars, PasswordResetEmailTemplateVars) password_reset_url = template_vars.password_reset_link - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email return oi @@ -319,11 +323,12 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]): oi_send_email = oi.send_email async def send_email( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): template_vars.user.email = "override@example.com" assert isinstance(template_vars, PasswordResetEmailTemplateVars) - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email return oi @@ -332,7 +337,9 @@ class CustomEmailService( emailpassword.EmailDeliveryInterface[emailpassword.EmailTemplateVars] ): async def send_email( - self, template_vars: Any, user_context: Dict[str, Any] + self, + template_vars: Any, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> None: nonlocal email, password_reset_url email = template_vars.user.email @@ -384,7 +391,10 @@ async def test_reset_password_smtp_service(driver_config_client: TestClient): def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -396,7 +406,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, password_reset_url get_content_called = True @@ -433,11 +444,12 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email_override return oi @@ -486,7 +498,8 @@ async def test_reset_password_for_non_existent_user(driver_config_client: TestCl def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -498,7 +511,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, password_reset_url get_content_called = True @@ -535,11 +549,12 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email_override return oi @@ -613,7 +628,7 @@ async def test_email_verification_default_backward_compatibility( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -679,7 +694,7 @@ async def test_email_verification_default_backward_compatibility_suppress_error( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -727,7 +742,9 @@ class CustomEmailService( async def send_email( self, template_vars: emailverification.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ) -> None: nonlocal email, email_verify_url email = template_vars.user.email @@ -762,7 +779,7 @@ async def send_email( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) res = email_verify_token_request( @@ -792,7 +809,8 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email( - template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal email, email_verify_url email = template_vars.user.email @@ -833,7 +851,7 @@ async def send_email( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -878,7 +896,9 @@ class CustomEmailDeliveryService( async def send_email( self, template_vars: emailpassword.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal email, password_reset_url email = template_vars.user.email @@ -925,7 +945,8 @@ async def test_email_verification_smtp_service(driver_config_client: TestClient) def smtp_service_override(oi: SMTPServiceInterface[VerificationEmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -937,7 +958,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: VerificationEmailTemplateVars, _user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, email_verify_url get_content_called = True @@ -974,7 +996,8 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True @@ -1013,7 +1036,7 @@ async def send_email_override( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( @@ -1045,7 +1068,9 @@ class CustomEmailDeliveryService( async def send_email( self, template_vars: PasswordResetEmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal reset_url, token_info, tenant_info, query_length password_reset_url = template_vars.password_reset_link @@ -1081,7 +1106,9 @@ async def send_email( dict_response = json.loads(response_1.text) user_info = dict_response["user"] assert dict_response["status"] == "OK" - resp = await send_reset_password_email("public", user_info["id"], "") + resp = await send_reset_password_email( + "public", user_info["id"], "random@gmail.com" + ) assert resp == "OK" assert reset_url == "http://supertokens.io/auth/reset-password" @@ -1118,9 +1145,13 @@ async def test_send_reset_password_email_invalid_input( dict_response = json.loads(response_1.text) user_info = dict_response["user"] - link = await send_reset_password_email("public", "invalidUserId", "") + link = await send_reset_password_email( + "public", "invalidUserId", "random@gmail.com" + ) assert link == "UNKNOWN_USER_ID_ERROR" with raises(Exception) as err: - await send_reset_password_email("invalidTenantId", user_info["id"], "") + await send_reset_password_email( + "invalidTenantId", user_info["id"], "random@gmail.com" + ) assert "status code: 400" in str(err.value) diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 0bf066d68..a87ed908b 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -198,7 +198,9 @@ async def test_the_generate_token_api_with_valid_input_email_verified_and_test_e user_id = dict_response["user"]["id"] cookies = extract_all_cookies(response_1) - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) @@ -742,7 +744,7 @@ async def email_verify_post( await asyncio.sleep(1) if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -972,7 +974,7 @@ async def email_verify_post( if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -1083,7 +1085,7 @@ async def email_verify_post( if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -1121,7 +1123,9 @@ async def test_the_generate_token_api_with_valid_input_and_then_remove_token( assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) await revoke_email_verification_tokens("public", user_id) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): @@ -1165,7 +1169,9 @@ async def test_the_generate_token_api_with_valid_input_verify_and_then_unverify_ assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) @@ -1269,7 +1275,9 @@ async def send_email( cookies = extract_all_cookies(res) # Start verification: - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) assert isinstance(verify_token, CreateEmailVerificationTokenOkResult) await verify_email_using_token("public", verify_token.token) @@ -1370,7 +1378,7 @@ def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str: dict_response = json.loads(response_1.text) assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - email = dict_response["user"]["email"] + email = dict_response["user"]["emails"][0] await send_email_verification_email( "public", user_id, RecipeUserId(user_id), email, {"url": "localhost:3000"} From f36ddd0fb5a14b25256970e1ef7b2d32693ba512 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 13:46:28 +0530 Subject: [PATCH 055/126] fixes more tests --- tests/emailpassword/test_passwordreset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/emailpassword/test_passwordreset.py b/tests/emailpassword/test_passwordreset.py index 89c6af4ce..26856b78b 100644 --- a/tests/emailpassword/test_passwordreset.py +++ b/tests/emailpassword/test_passwordreset.py @@ -399,7 +399,7 @@ async def send_email( dict_response = json.loads(response_4.text) assert dict_response["status"] == "OK" assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio From 12d6bffdbda7e265060e325f754f280c2357fdea Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 13:58:25 +0530 Subject: [PATCH 056/126] fixes an implementation --- supertokens_python/types.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/supertokens_python/types.py b/supertokens_python/types.py index e4b5ee639..936f6f191 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -76,17 +76,15 @@ def __eq__(self, other: Any) -> bool: self.recipe_id == other.recipe_id and self.recipe_user_id == other.recipe_user_id and self.tenant_ids == other.tenant_ids - and self.has_same_email_as(other.email) - and self.has_same_phone_number_as(other.phone_number) - and self.has_same_third_party_info_as(other.third_party) + and self.email == other.email + and self.phone_number == other.phone_number + and self.third_party == other.third_party and self.time_joined == other.time_joined and self.verified == other.verified ) return False def has_same_email_as(self, email: Union[str, None]) -> bool: - if self.email is None and email is None: - return True if email is None: return False return ( @@ -95,8 +93,6 @@ def has_same_email_as(self, email: Union[str, None]) -> bool: ) def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: - if self.phone_number is None and phone_number is None: - return True if phone_number is None: return False @@ -113,8 +109,6 @@ def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: def has_same_third_party_info_as( self, third_party: Union[ThirdPartyInfo, None] ) -> bool: - if third_party is None and self.third_party is None: - return True if third_party is None or self.third_party is None: return False return ( From 85221eac0b4810ac17b78a5349409cd07369d92a Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 17:59:44 +0530 Subject: [PATCH 057/126] fixes more tests --- tests/emailpassword/test_signin.py | 6 ++-- tests/multitenancy/test_tenants_crud.py | 36 ++++++------------------ tests/passwordless/test_emaildelivery.py | 7 +++-- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/tests/emailpassword/test_signin.py b/tests/emailpassword/test_signin.py index 6203022fc..42f1b86d7 100644 --- a/tests/emailpassword/test_signin.py +++ b/tests/emailpassword/test_signin.py @@ -180,7 +180,7 @@ async def test_singinAPI_works_when_input_is_fine(driver_config_client: TestClie assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio @@ -225,7 +225,7 @@ async def test_singinAPI_works_when_input_is_fine_when_rid_is_tpep( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio @@ -270,7 +270,7 @@ async def test_singinAPI_works_when_input_is_fine_when_rid_is_emailpassword( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 7127e0735..73d60a388 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -77,7 +77,7 @@ async def test_tenant_crud(): await create_or_update_tenant( "t2", TenantConfigCreateOrUpdate( - first_factors=["otp-email, otp-phone, link-email, link-phone"] + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] ), ) await create_or_update_tenant( @@ -85,39 +85,30 @@ async def test_tenant_crud(): ) tenants = await list_all_tenants() - assert len(tenants.tenants) == 3 + assert len(tenants.tenants) == 4 t1_config = await get_tenant("t1") assert t1_config is not None assert t1_config.first_factors is not None assert "emailpassword" in t1_config.first_factors - assert "otp-email" in t1_config.first_factors - assert "otp-phone" in t1_config.first_factors - assert "link-email" in t1_config.first_factors - assert "link-phone" in t1_config.first_factors - assert "thirdparty" in t1_config.first_factors + assert len(t1_config.first_factors) == 1 assert t1_config.core_config == {} t2_config = await get_tenant("t2") assert t2_config is not None assert t2_config.first_factors is not None - assert "emailpassword" in t2_config.first_factors assert "otp-email" in t2_config.first_factors assert "otp-phone" in t2_config.first_factors assert "link-email" in t2_config.first_factors assert "link-phone" in t2_config.first_factors - assert "thirdparty" in t2_config.first_factors + assert len(t2_config.first_factors) == 4 assert t2_config.core_config == {} t3_config = await get_tenant("t3") assert t3_config is not None assert t3_config.first_factors is not None - assert "emailpassword" in t3_config.first_factors - assert "otp-email" in t3_config.first_factors - assert "otp-phone" in t3_config.first_factors - assert "link-email" in t3_config.first_factors - assert "link-phone" in t3_config.first_factors assert "thirdparty" in t3_config.first_factors + assert len(t3_config.first_factors) == 1 assert t3_config.core_config == {} # update tenant1 to add passwordless: @@ -126,37 +117,26 @@ async def test_tenant_crud(): TenantConfigCreateOrUpdate( first_factors=[ "otp-email", - "otp-phone", - "link-email", - "link-phone", ] ), ) t1_config = await get_tenant("t1") assert t1_config is not None assert t1_config.first_factors is not None - assert "emailpassword" in t1_config.first_factors assert "otp-email" in t1_config.first_factors - assert "otp-phone" in t1_config.first_factors - assert "link-email" in t1_config.first_factors - assert "link-phone" in t1_config.first_factors - assert "thirdparty" in t1_config.first_factors + assert len(t1_config.first_factors) == 1 assert t1_config.core_config == {} # update tenant1 to add thirdparty: await create_or_update_tenant( - "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty", "otp-email"]) ) t1_config = await get_tenant("t1") assert t1_config is not None assert t1_config.first_factors is not None - assert "emailpassword" in t1_config.first_factors assert "otp-email" in t1_config.first_factors - assert "otp-phone" in t1_config.first_factors - assert "link-email" in t1_config.first_factors - assert "link-phone" in t1_config.first_factors assert "thirdparty" in t1_config.first_factors - assert t1_config.core_config == {} + assert len(t1_config.first_factors) == 2 assert t1_config.core_config == {} # delete tenant2: diff --git a/tests/passwordless/test_emaildelivery.py b/tests/passwordless/test_emaildelivery.py index 5fcca30d4..60f4bf8c9 100644 --- a/tests/passwordless/test_emaildelivery.py +++ b/tests/passwordless/test_emaildelivery.py @@ -59,7 +59,7 @@ create_email_verification_token, ) from supertokens_python.recipe.emailverification.interfaces import ( - CreateEmailVerificationTokenOkResult, + CreateEmailVerificationTokenEmailAlreadyVerifiedError, ) @@ -186,8 +186,9 @@ async def send_email_override( "public", pless_response.recipe_user_id ) - assert isinstance(create_token, CreateEmailVerificationTokenOkResult) - # TODO: Replaced CreateEmailVerificationTokenEmailAlreadyVerifiedError. Confirm if this is correct. + assert isinstance( + create_token, CreateEmailVerificationTokenEmailAlreadyVerifiedError + ) assert ( all([outer_override_called, get_content_called, send_raw_email_called]) is False From 19a13bfc4c8b7c809c2b30359b84c133849f74db Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 18:43:51 +0530 Subject: [PATCH 058/126] fixes more tests --- tests/sessions/claims/test_create_new_session.py | 6 +++--- tests/sessions/claims/test_primitive_array_claim.py | 11 +++++++++-- tests/sessions/claims/test_primitive_claim.py | 8 +++++++- tests/sessions/claims/test_set_claim_value.py | 4 ++-- tests/sessions/claims/test_verify_session.py | 2 +- tests/sessions/claims/utils.py | 4 ++-- tests/sessions/test_access_token_version.py | 2 +- 7 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index 1207f8335..dbe8f76c0 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -32,7 +32,7 @@ async def test_create_access_token_payload_with_session_claims(timestamp: int): s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"v": True, "t": timestamp} @@ -45,7 +45,7 @@ async def test_should_create_access_token_payload_with_session_claims_with_an_no s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 9 + assert len(payload) == 10 assert payload.get("st-true") is None @@ -70,6 +70,6 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 11 + assert len(payload) == 12 assert payload["st-true"] == {"v": True, "t": timestamp} assert payload["user-custom-claim"] == "foo" diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index 91c3c5ee5..939ef9c73 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -80,9 +80,16 @@ async def test_primitive_claim_matching__add_to_payload(): async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveArrayClaim("key", sync_fetch_value) user_id, ctx = "user_id", {} - await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + recipe_user_id = RecipeUserId(user_id) + await claim.build(user_id, recipe_user_id, DEFAULT_TENANT_ID, ctx) assert sync_fetch_value.call_count == 1 - assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ + assert ( + user_id, + recipe_user_id, + DEFAULT_TENANT_ID, + ctx, + {}, + ) == sync_fetch_value.call_args_list[0][ 0 ] # extra [0] refers to call params diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 315e2b153..3fde9dd38 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -49,7 +49,13 @@ async def test_primitive_claim_fetch_value_params_correct(): user_id, ctx = "user_id", {} await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) assert sync_fetch_value.call_count == 1 - assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ + assert ( + user_id, + RecipeUserId(user_id), + DEFAULT_TENANT_ID, + ctx, + {}, + ) == sync_fetch_value.call_args_list[0][ 0 ] # extra [0] refers to call params diff --git a/tests/sessions/claims/test_set_claim_value.py b/tests/sessions/claims/test_set_claim_value.py index 91c518f8e..35d56c654 100644 --- a/tests/sessions/claims/test_set_claim_value.py +++ b/tests/sessions/claims/test_set_claim_value.py @@ -62,7 +62,7 @@ async def test_should_overwrite_claim_value(timestamp: int): s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"t": timestamp, "v": True} await s.set_claim_value(TrueClaim, False) @@ -81,7 +81,7 @@ async def test_should_overwrite_claim_value_using_session_handle(timestamp: int) s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"t": timestamp, "v": True} await set_claim_value(s.get_handle(), TrueClaim, False) diff --git a/tests/sessions/claims/test_verify_session.py b/tests/sessions/claims/test_verify_session.py index 12f3b5461..ff5ac669d 100644 --- a/tests/sessions/claims/test_verify_session.py +++ b/tests/sessions/claims/test_verify_session.py @@ -447,7 +447,7 @@ async def test_should_allow_if_assert_claims_returns_no_error( assert validators == [validator] assert ctx["_default"]["request"] recipe_impl_mock.validate_claims.assert_called_once_with( # type: ignore - "test_user_id", {}, [validator], ctx + "test_user_id", RecipeUserId("test_user_id"), {}, [validator], ctx ) diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index b55fd7f6e..92a78c96d 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -9,8 +9,8 @@ from supertokens_python.types import RecipeUserId from tests.utils import st_init_common_args -TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___: True) # type: ignore -NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___: None) # type: ignore +TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___, _____, ______: True) +NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___, _____, ______: None) def session_functions_override_with_claim( diff --git a/tests/sessions/test_access_token_version.py b/tests/sessions/test_access_token_version.py index 4000bebd5..8b532590a 100644 --- a/tests/sessions/test_access_token_version.py +++ b/tests/sessions/test_access_token_version.py @@ -47,7 +47,7 @@ async def test_access_token_v4(): False, ) assert res["userId"] == "user-id" - assert parsed_info.version == 4 + assert parsed_info.version == 5 async def test_parsing_access_token_v2(): From 4989a8fbd89e0c4d6fc96ce354cd5aa943be0b41 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 18:58:11 +0530 Subject: [PATCH 059/126] fixes more tests --- tests/thirdparty/test_emaildelivery.py | 10 +++++----- tests/thirdparty/test_multitenancy.py | 12 ++++++------ tests/thirdparty/test_thirdparty.py | 5 ++++- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/thirdparty/test_emaildelivery.py b/tests/thirdparty/test_emaildelivery.py index dd33e2bdf..207626613 100644 --- a/tests/thirdparty/test_emaildelivery.py +++ b/tests/thirdparty/test_emaildelivery.py @@ -140,7 +140,7 @@ async def test_email_verify_default_backward_compatibility( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com", True, None + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() @@ -214,7 +214,7 @@ async def test_email_verify_default_backward_compatibility_supress_error( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com", True, None + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() @@ -304,7 +304,7 @@ async def send_email( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com", True, None + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() @@ -382,7 +382,7 @@ async def send_email( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com", True, None + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() @@ -522,7 +522,7 @@ async def send_email_override( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com", True, None + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() diff --git a/tests/thirdparty/test_multitenancy.py b/tests/thirdparty/test_multitenancy.py index 4fbd61aa6..33666c845 100644 --- a/tests/thirdparty/test_multitenancy.py +++ b/tests/thirdparty/test_multitenancy.py @@ -178,12 +178,12 @@ async def test_thirtyparty_multitenancy_functions(): ), ) - assert g_user_by_tpid1a == user1a.user - assert g_user_by_tpid1b == user1b.user - assert g_user_by_tpid2a == user2a.user - assert g_user_by_tpid2b == user2b.user - assert g_user_by_tpid3a == user3a.user - assert g_user_by_tpid3b == user3b.user + assert g_user_by_tpid1a == [user1a.user] + assert g_user_by_tpid1b == [user1b.user] + assert g_user_by_tpid2a == [user2a.user] + assert g_user_by_tpid2b == [user2b.user] + assert g_user_by_tpid3a == [user3a.user] + assert g_user_by_tpid3b == [user3b.user] async def test_get_provider(): diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py index 5cfd928cd..ce567f059 100644 --- a/tests/thirdparty/test_thirdparty.py +++ b/tests/thirdparty/test_thirdparty.py @@ -396,4 +396,7 @@ async def test_signinup_generating_fake_email( assert res.status_code == 200 res_json = res.json() assert res_json["status"] == "OK" - assert res_json["user"]["email"] == "customid.custom@stfakeemail.supertokens.com" + assert ( + res_json["user"]["emails"][0] == "customid.custom@stfakeemail.supertokens.com" + ) + assert len(res_json["user"]["emails"]) == 1 From d604a7f974597997b1f90f159f113b7892e33f62 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 19:11:26 +0530 Subject: [PATCH 060/126] fixes more tests --- tests/useridmapping/recipe_tests.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/useridmapping/recipe_tests.py b/tests/useridmapping/recipe_tests.py index cf4011a2d..20b1309e0 100644 --- a/tests/useridmapping/recipe_tests.py +++ b/tests/useridmapping/recipe_tests.py @@ -84,7 +84,7 @@ async def ep_get_existing_user_by_signin(email: str) -> str: async def ep_get_existing_user_after_reset_password(user_id: str) -> str: - new_password = "password" + new_password = "password1234" from supertokens_python.recipe.emailpassword.asyncio import ( create_reset_password_token, reset_password_using_token, @@ -105,10 +105,12 @@ async def ep_get_existing_user_after_updating_email_and_sign_in(user_id: str) -> sign_in, ) - res = await update_email_or_password(RecipeUserId(user_id), new_email, "password") - assert isinstance(res, SignUpOkResult) + res = await update_email_or_password( + RecipeUserId(user_id), new_email, "password1234" + ) + assert isinstance(res, UpdateEmailOrPasswordOkResult) - res = await sign_in("public", new_email, "password") + res = await sign_in("public", new_email, "password1234") assert isinstance(res, SignInOkResult) return res.user.id @@ -128,7 +130,7 @@ async def test_get_user_id_mapping(use_external_id_info: bool): external_user_id = "externalId" external_id_info = "externalIdInfo" if use_external_id_info else None - assert ep_get_existing_user_id(supertokens_user_id) == supertokens_user_id + assert await ep_get_existing_user_id(supertokens_user_id) == supertokens_user_id # Create user id mapping res = await create_user_id_mapping( @@ -138,16 +140,17 @@ async def test_get_user_id_mapping(use_external_id_info: bool): # Now we should get the external user ID instead of ST user ID # irrespective of whether we pass ST User ID or External User ID - assert ep_get_existing_user_id(supertokens_user_id) == external_user_id - assert ep_get_existing_user_id(external_user_id) == external_user_id + assert await ep_get_existing_user_id(supertokens_user_id) == external_user_id + assert await ep_get_existing_user_id(external_user_id) == external_user_id # Same happens for all the functions - assert ep_get_existing_user_by_email(email) == external_user_id - assert ep_get_existing_user_by_signin(email) == external_user_id + assert await ep_get_existing_user_by_email(email) == external_user_id + assert await ep_get_existing_user_by_signin(email) == external_user_id assert ( - ep_get_existing_user_after_reset_password(external_user_id) == external_user_id + await ep_get_existing_user_after_reset_password(external_user_id) + == external_user_id ) assert ( - ep_get_existing_user_after_updating_email_and_sign_in(external_user_id) + await ep_get_existing_user_after_updating_email_and_sign_in(external_user_id) == external_user_id ) From 05c1242c92e05d0164f17b9ee2e9af9fdad79944 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 25 Sep 2024 20:50:34 +0530 Subject: [PATCH 061/126] fixes more tests --- tests/emailpassword/test_emailverify.py | 12 ++++++------ tests/sessions/claims/test_primitive_array_claim.py | 4 ++-- tests/sessions/claims/test_set_claim_value.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index a87ed908b..06314917e 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -1126,7 +1126,7 @@ async def test_the_generate_token_api_with_valid_input_and_then_remove_token( verify_token = await create_email_verification_token( "public", RecipeUserId(user_id) ) - await revoke_email_verification_tokens("public", user_id) + await revoke_email_verification_tokens("public", RecipeUserId(user_id)) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): response = await verify_email_using_token("public", verify_token.token) @@ -1175,11 +1175,11 @@ async def test_the_generate_token_api_with_valid_input_verify_and_then_unverify_ if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) - assert await is_email_verified(user_id) + assert await is_email_verified(RecipeUserId(user_id)) - await unverify_email(user_id) + await unverify_email(RecipeUserId(user_id)) - is_verified = await is_email_verified(user_id) + is_verified = await is_email_verified(RecipeUserId(user_id)) assert is_verified is False return raise Exception("Test failed") @@ -1304,10 +1304,10 @@ async def send_email( ) assert res.status_code == 200 assert res.json() == {"status": "EMAIL_ALREADY_VERIFIED_ERROR"} - assert "front-token" not in res.headers + assert res.headers.get("front-token") is not None # now we mark the email as unverified and try again: - await unverify_email(user_id) + await unverify_email(RecipeUserId(user_id)) res = email_verify_token_request( driver_config_client, cookies["sAccessToken"]["value"], diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index 939ef9c73..35c930d93 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -426,13 +426,13 @@ async def test_validator_excludes_all_should_validate_matching_payload(): async def test_validator_should_not_validate_older_values_with_5min_default_max_age( patch_get_timestamp_ms: MagicMock, ): - claim = PrimitiveArrayClaim("key", sync_fetch_value, 3000) # 5 mins + claim = PrimitiveArrayClaim("key", sync_fetch_value, 300) # 5 mins payload = await claim.build( "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} ) # Increase clock time by 10 MINS: - patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore + patch_get_timestamp_ms.return_value += 10 * MINS res = await resolve(claim.validators.includes(included_item).validate(payload, {})) assert res.is_valid is False diff --git a/tests/sessions/claims/test_set_claim_value.py b/tests/sessions/claims/test_set_claim_value.py index 35d56c654..a10677692 100644 --- a/tests/sessions/claims/test_set_claim_value.py +++ b/tests/sessions/claims/test_set_claim_value.py @@ -69,7 +69,7 @@ async def test_should_overwrite_claim_value(timestamp: int): # Payload should be updated now: payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"t": timestamp, "v": False} From 8e3a7b2216608ca2a38ca0639be276810077afab Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 01:25:47 +0530 Subject: [PATCH 062/126] fixes frontendintegration tests with st-website repo --- .vscode/launch.json | 20 ++++++++++++++++++++ supertokens_python/supertokens.py | 3 +++ 2 files changed, 23 insertions(+) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..a6f58a695 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,20 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Flask, supertokens-website tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/frontendIntegration/flask-server/app.py", + "args": [ + "--port", + "8080" + ], + "cwd": "${workspaceFolder}/tests/frontendIntegration/flask-server", + "env": { + "FLASK_DEBUG": "1" + }, + "jinja": true + } + ] +} \ No newline at end of file diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index eb283ead3..d4509a74c 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -325,6 +325,9 @@ def reset(): environ["SUPERTOKENS_ENV"] != "testing" ): raise_general_exception("calling testing function in non testing env") + from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe + + UserMetadataRecipe.reset() Querier.reset() Supertokens.__instance = None From e0995f074b853868cb0f4e559535d1c268327737 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 15:43:53 +0530 Subject: [PATCH 063/126] fixes stuff --- .pylintrc | 3 +- .vscode/launch.json | 15 + supertokens_python/__init__.py | 5 + .../recipe/multifactorauth/__init__.py | 35 ++ supertokens_python/recipe/totp/__init__.py | 33 ++ supertokens_python/supertokens.py | 9 +- tests/auth-react/flask-server/app.py | 553 ++++++++++++++++-- tests/multitenancy/test_tenants_crud.py | 8 +- ...test_validate_claims_for_session_handle.py | 2 +- tests/test-server/multitenancy.py | 2 +- tests/userroles/test_claims.py | 7 +- 11 files changed, 623 insertions(+), 49 deletions(-) create mode 100644 supertokens_python/recipe/multifactorauth/__init__.py create mode 100644 supertokens_python/recipe/totp/__init__.py diff --git a/.pylintrc b/.pylintrc index f9b838184..c271a2fb9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -123,7 +123,8 @@ disable=raw-checker-failed, consider-using-in, no-else-return, no-self-use, - no-else-raise + no-else-raise, + too-many-nested-blocks, # Enable the message, report, category or checker with the given id(s). You can diff --git a/.vscode/launch.json b/.vscode/launch.json index a6f58a695..9c01eaf88 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -15,6 +15,21 @@ "FLASK_DEBUG": "1" }, "jinja": true + }, + { + "name": "Python: Flask, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/flask-server/app.py", + "args": [ + "--port", + "8083" + ], + "cwd": "${workspaceFolder}/tests/auth-react/flask-server", + "env": { + "FLASK_DEBUG": "1" + }, + "jinja": true } ] } \ No newline at end of file diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index ea7ba506a..43d3573b2 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -17,6 +17,7 @@ from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest +from supertokens_python.types import RecipeUserId from . import supertokens from .recipe_module import RecipeModule @@ -49,3 +50,7 @@ def get_request_from_user_context( user_context: Optional[Dict[str, Any]], ) -> Optional[BaseRequest]: return Supertokens.get_instance().get_request_from_user_context(user_context) + + +def convert_to_recipe_user_id(user_id: str) -> RecipeUserId: + return RecipeUserId(user_id) diff --git a/supertokens_python/recipe/multifactorauth/__init__.py b/supertokens_python/recipe/multifactorauth/__init__.py new file mode 100644 index 000000000..fa3aaf6f8 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Optional, Union + +from supertokens_python.recipe.multifactorauth.types import OverrideConfig + +from .recipe import MultiFactorAuthRecipe + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + from ...recipe_module import RecipeModule + + +def init( + first_factors: Optional[List[str]] = None, + override: Union[OverrideConfig, None] = None, +) -> Callable[[AppInfo], RecipeModule]: + return MultiFactorAuthRecipe.init( + first_factors, + override, + ) diff --git a/supertokens_python/recipe/totp/__init__.py b/supertokens_python/recipe/totp/__init__.py new file mode 100644 index 000000000..f89944688 --- /dev/null +++ b/supertokens_python/recipe/totp/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Union + +from supertokens_python.recipe.totp.types import TOTPConfig + +from .recipe import TOTPRecipe + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + from ...recipe_module import RecipeModule + + +def init( + config: Union[TOTPConfig, None] = None, +) -> Callable[[AppInfo], RecipeModule]: + return TOTPRecipe.init( + config=config, + ) diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index d4509a74c..6aae8aa9b 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -255,11 +255,6 @@ def __init__( "Please provide at least one recipe to the supertokens.init function call" ) - from supertokens_python.recipe.multifactorauth.recipe import ( - MultiFactorAuthRecipe, - ) - from supertokens_python.recipe.totp.recipe import TOTPRecipe - multitenancy_found = False totp_found = False user_metadata_found = False @@ -272,9 +267,9 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: multitenancy_found = True elif recipe_module.get_recipe_id() == "usermetadata": user_metadata_found = True - elif recipe_module.get_recipe_id() == MultiFactorAuthRecipe.recipe_id: + elif recipe_module.get_recipe_id() == "multifactorauth": multi_factor_auth_found = True - elif recipe_module.get_recipe_id() == TOTPRecipe.recipe_id: + elif recipe_module.get_recipe_id() == "totp": totp_found = True return recipe_module diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index c44758c08..55ed847c8 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from dotenv import load_dotenv from flask import Flask, g, jsonify, make_response, request @@ -26,20 +26,56 @@ get_all_cors_headers, init, ) +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.framework.flask.flask_middleware import Middleware from supertokens_python.framework.request import BaseRequest from supertokens_python.recipe import ( + accountlinking, emailpassword, emailverification, passwordless, session, thirdparty, + totp, userroles, ) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import ( + AccountInfoWithRecipeIdAndUserId, +) from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EmailPasswordAPIInterface, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, +) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) +from supertokens_python.recipe.multitenancy.syncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.passwordless.syncio import update_user +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, ) from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo from supertokens_python.recipe.emailpassword.interfaces import ( @@ -72,6 +108,11 @@ ) from supertokens_python.recipe.passwordless.interfaces import ( APIInterface as PasswordlessAPIInterface, + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, ) from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionRecipe @@ -86,12 +127,16 @@ SessionClaimValidator, SessionContainer, ) -from supertokens_python.recipe.thirdparty import ThirdPartyRecipe +from supertokens_python.recipe.thirdparty import ProviderConfig, ThirdPartyRecipe from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface as ThirdpartyAPIInterface, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider +from supertokens_python.recipe.thirdparty.syncio import manually_create_or_update_user +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import ( PermissionClaim, @@ -104,10 +149,12 @@ ) from supertokens_python.types import ( AccountInfo, + RecipeUserId, User, GeneralErrorResponse, ) -from supertokens_python.syncio import delete_user, list_users_by_account_info +from supertokens_python.syncio import delete_user, get_user, list_users_by_account_info +from supertokens_python.recipe import multifactorauth load_dotenv() @@ -129,6 +176,14 @@ def get_website_domain(): latest_url_with_token = None code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None class CustomPlessEmailService( @@ -266,12 +321,11 @@ async def get_user_info( # pylint: disable=no-self-use return oi -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def custom_init(): + global contact_method + global flow_type + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() @@ -283,6 +337,8 @@ def custom_init( DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -659,6 +715,14 @@ async def resend_code_post( ), ] + global enabled_providers + if enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in enabled_providers + ] + if contact_method is not None and flow_type is not None: if contact_method == "PHONE": passwordless_init = passwordless.init( @@ -706,33 +770,243 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + global mfa_info + + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if mfa_info.get("alreadySetup"): + return mfa_info.get("alreadySetup", []) + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if mfa_info.get("allowedToSetup"): + if factor_id not in mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if mfa_info.get("requirements"): + return mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if mfa_info.get("alreadySetup"): + res.factors.already_setup = mfa_info["alreadySetup"][:] + + if mfa_info.get("noContacts"): + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailpassword.EmailDeliveryConfig(CustomEPEmailService()), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), + ), + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + global accountlinking_config + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + global enabled_recipes + if enabled_recipes is not None: + recipe_list = [ + item["init"] for item in recipe_list if item["id"] in enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( @@ -771,6 +1045,180 @@ def ping(): return "success" +@app.route("/changeEmail", methods=["POST"]) # type: ignore +def change_email(): + body: Union[Any, None] = request.get_json() + if body is None: + raise Exception("Should never come here") + from supertokens_python.recipe.emailpassword.syncio import update_email_or_password + from supertokens_python import convert_to_recipe_user_id + + if body["rid"] == "emailpassword": + resp = update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return jsonify({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + # password policy violation error + return jsonify(resp.to_json()) + elif body["rid"] == "thirdparty": + user = get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return jsonify( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return jsonify({"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason}) + # EmailChangeNotAllowedError + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return jsonify({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return jsonify( + {"status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + + raise Exception("Should not come here") + + +@app.route("/setupTenant", methods=["POST"]) # type: ignore +def setup_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body["coreConfig"] + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + if ( + len(provider) > 1 + ): # TODO: remove this once all tests pass, this is just for making sure we pass the right stuff into ProviderConfig + raise Exception("Pass more stuff into ProviderConfig:" + str(provider)) + create_or_update_third_party_config( + tenant_id, + config=ProviderConfig( + third_party_id=provider["id"], + ), + ) + + return jsonify({"status": "OK", "createdNew": core_resp.created_new}) + + +@app.route("/addUserToTenant", methods=["POST"]) # type: ignore +def add_user_to_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return jsonify( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return jsonify({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return jsonify( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +@app.route("/removeUserFromTenant", methods=["POST"]) # type: ignore +def remove_user_from_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = disassociate_user_from_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + return jsonify({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +@app.route("/removeTenant", methods=["POST"]) # type: ignore +def remove_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = delete_tenant(tenant_id) + + return jsonify({"status": "OK", "didExist": core_resp.did_exist}) + + @app.route("/sessionInfo", methods=["GET"]) # type: ignore @verify_session() def get_session_info(): @@ -794,7 +1242,15 @@ def get_token(): @app.route("/beforeeach", methods=["POST"]) # type: ignore def before_each(): global code_store + global accountlinking_config + global enabled_providers + global enabled_recipes + global mfa_info code_store = dict() + accountlinking_config = {} + enabled_providers = None + enabled_recipes = None + mfa_info = {} custom_init() return "" @@ -804,12 +1260,38 @@ def test_set_flow(): body: Union[Any, None] = request.get_json() if body is None: raise Exception("Should never come here") + global contact_method + global flow_type contact_method = body["contactMethod"] flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + custom_init() return "" +@app.route("/test/setAccountLinkingConfig", methods=["POST"]) # type: ignore +def test_set_account_linking_config(): + global accountlinking_config + body = request.get_json() + if body is None: + raise Exception("Invalid request body") + accountlinking_config = body + custom_init() + return "", 200 + + +@app.route("/test/setEnabledRecipes", methods=["POST"]) # type: ignore +def test_set_enabled_recipes(): + global enabled_recipes + global enabled_providers + body = request.get_json() + if body is None: + raise Exception("Invalid request body") + enabled_recipes = body.get("enabledRecipes") + enabled_providers = body.get("enabledProviders") + custom_init() + return "", 200 + + @app.get("/test/getDevice") # type: ignore def test_get_device(): global code_store @@ -828,6 +1310,11 @@ def test_feature_flags(): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] return jsonify({"available": available}) diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 73d60a388..f048fc99a 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -42,7 +42,7 @@ create_or_update_third_party_config, delete_third_party_config, associate_user_to_tenant, - dissociate_user_from_tenant, + disassociate_user_from_tenant, ) from supertokens_python.recipe.emailpassword.asyncio import sign_up from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult @@ -305,9 +305,9 @@ async def test_user_association_and_disassociation_with_tenants(): assert user is not None assert len(user.tenant_ids) == 4 # public + 3 tenants - await dissociate_user_from_tenant("t1", RecipeUserId(user_id)) - await dissociate_user_from_tenant("t2", RecipeUserId(user_id)) - await dissociate_user_from_tenant("t3", RecipeUserId(user_id)) + await disassociate_user_from_tenant("t1", RecipeUserId(user_id)) + await disassociate_user_from_tenant("t2", RecipeUserId(user_id)) + await disassociate_user_from_tenant("t3", RecipeUserId(user_id)) user = await get_user(user_id) assert user is not None diff --git a/tests/sessions/claims/test_validate_claims_for_session_handle.py b/tests/sessions/claims/test_validate_claims_for_session_handle.py index 71c62c69d..ccad5b464 100644 --- a/tests/sessions/claims/test_validate_claims_for_session_handle.py +++ b/tests/sessions/claims/test_validate_claims_for_session_handle.py @@ -41,7 +41,7 @@ async def test_should_return_the_right_validation_errors(): ) assert isinstance(res, ClaimsValidationResult) and len(res.invalid_claims) == 1 - assert res.invalid_claims[0].id == failing_validator.id + assert res.invalid_claims[0].id_ == failing_validator.id assert res.invalid_claims[0].reason == { "message": "value does not exist", "actualValue": None, diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index ff2aac785..4da9df85a 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -203,7 +203,7 @@ def disassociate_user_from_tenant(): # type: ignore user_id = data["userId"] user_context = data.get("userContext") - response = multitenancy.dissociate_user_from_tenant( + response = multitenancy.disassociate_user_from_tenant( tenant_id, user_id, user_context ) diff --git a/tests/userroles/test_claims.py b/tests/userroles/test_claims.py index 455f915b7..db13577fb 100644 --- a/tests/userroles/test_claims.py +++ b/tests/userroles/test_claims.py @@ -135,7 +135,8 @@ async def test_should_validate_roles(): assert e.typename == "InvalidClaimsError" err: ClaimValidationError (err,) = e.value.payload # type: ignore - assert err.id == UserRoleClaim.key + assert isinstance(err, ClaimValidationError) + assert err.id_ == UserRoleClaim.key assert err.reason == { "message": "wrong value", "expectedToInclude": invalid_role, @@ -196,8 +197,10 @@ async def test_should_validate_permissions(): assert e.typename == "InvalidClaimsError" err: ClaimValidationError (err,) = e.value.payload # type: ignore - assert err.id == PermissionClaim.key + assert isinstance(err, ClaimValidationError) + assert err.id_ == PermissionClaim.key assert err.reason is not None + assert isinstance(err.reason, dict) actual_value = err.reason.pop("actualValue") assert sorted(actual_value) == sorted(permissions) assert err.reason == { From eae7e2de900467fc1d8213c9bfc2b64cd87dc6a9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 15:44:24 +0530 Subject: [PATCH 064/126] cyclic import issue --- supertokens_python/auth_utils.py | 8 +++- .../multifactorauth/api/implementation.py | 23 +++++---- .../api/resync_session_and_fetch_mfa_info.py | 2 +- .../multifactorauth/asyncio/__init__.py | 16 +++++-- .../recipe/multifactorauth/interfaces.py | 6 +-- .../multi_factor_auth_claim.py | 46 +++++++++++------- .../recipe/multifactorauth/recipe.py | 7 +-- .../multifactorauth/recipe_implementation.py | 5 +- .../recipe/multifactorauth/syncio/__init__.py | 4 -- .../recipe/multifactorauth/types.py | 4 +- .../recipe/multifactorauth/utils.py | 47 +++++++++++++------ .../recipe/multitenancy/asyncio/__init__.py | 23 +++++++-- .../recipe/multitenancy/interfaces.py | 11 ++++- .../multitenancy/recipe_implementation.py | 6 ++- .../recipe/multitenancy/syncio/__init__.py | 6 +-- .../recipe/session/exceptions.py | 11 +++-- .../recipe/thirdparty/asyncio/__init__.py | 2 +- .../recipe/thirdparty/syncio/__init__.py | 2 +- supertokens_python/recipe/totp/interfaces.py | 19 +++++++- .../recipe/totp/syncio/__init__.py | 4 -- supertokens_python/recipe/totp/types.py | 8 ---- 21 files changed, 171 insertions(+), 89 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index da67d7b87..ff3315ba3 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -923,8 +923,14 @@ async def get_factors_set_up_for_user(): async def get_mfa_requirements_for_auth(): nonlocal mfa_info_prom if mfa_info_prom is None: + from .recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, + ) + mfa_info_prom = await update_and_get_mfa_related_info_in_session( - input_session=session, user_context=user_context + MultiFactorAuthClaim, + input_session=session, + user_context=user_context, ) return mfa_info_prom.mfa_requirements_for_auth diff --git a/supertokens_python/recipe/multifactorauth/api/implementation.py b/supertokens_python/recipe/multifactorauth/api/implementation.py index 9f3dd774b..d6fe15e98 100644 --- a/supertokens_python/recipe/multifactorauth/api/implementation.py +++ b/supertokens_python/recipe/multifactorauth/api/implementation.py @@ -12,15 +12,15 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +import importlib -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import Any, Dict, List, Union, TYPE_CHECKING from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.multifactorauth.utils import ( update_and_get_mfa_related_info_in_session, ) from supertokens_python.recipe.multitenancy.asyncio import get_tenant -from ..multi_factor_auth_claim import MultiFactorAuthClaim from supertokens_python.asyncio import get_user from supertokens_python.recipe.session.exceptions import ( InvalidClaimsError, @@ -28,12 +28,6 @@ UnauthorisedError, ) -if TYPE_CHECKING: - from supertokens_python.recipe.multifactorauth.interfaces import ( - APIInterface, - APIOptions, - ) - from supertokens_python.types import GeneralErrorResponse from ..interfaces import ( APIInterface, @@ -42,6 +36,11 @@ ResyncSessionAndFetchMFAInfoPUTOkResult, ) +if TYPE_CHECKING: + from ..multi_factor_auth_claim import ( + MultiFactorAuthClaimClass as MultiFactorAuthClaimType, + ) + class APIImplementation(APIInterface): async def resync_session_and_fetch_mfa_info_put( @@ -50,6 +49,11 @@ async def resync_session_and_fetch_mfa_info_put( session: SessionContainer, user_context: Dict[str, Any], ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + + mfa = importlib.import_module("supertokens_python.recipe.multifactorauth") + + MultiFactorAuthClaim: MultiFactorAuthClaimType = mfa.MultiFactorAuthClaim + session_user = await get_user(session.get_user_id(), user_context) if session_user is None: @@ -58,6 +62,7 @@ async def resync_session_and_fetch_mfa_info_put( ) mfa_info = await update_and_get_mfa_related_info_in_session( + MultiFactorAuthClaim, input_session=session, user_context=user_context, ) @@ -144,7 +149,7 @@ async def get_mfa_requirements_for_auth(): ) return ResyncSessionAndFetchMFAInfoPUTOkResult( factors=NextFactors( - next=next_factors, + next_=next_factors, already_setup=factors_setup_for_user, allowed_to_setup=factors_allowed_to_setup, ), diff --git a/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py b/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py index 1048253e6..8d7f1e8eb 100644 --- a/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py +++ b/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py @@ -27,7 +27,7 @@ async def handle_resync_session_and_fetch_mfa_info_api( - tenant_id: str, + _tenant_id: str, api_implementation: APIInterface, api_options: APIOptions, user_context: Dict[str, Any], diff --git a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py index 56817eb5f..4c8af2e54 100644 --- a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py +++ b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py @@ -21,7 +21,6 @@ from ..types import ( MFARequirementList, ) -from ..recipe import MultiFactorAuthRecipe from ..utils import update_and_get_mfa_related_info_in_session from supertokens_python.recipe.accountlinking.asyncio import get_user @@ -34,13 +33,17 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( if user_context is None: user_context = {} + from ..multi_factor_auth_claim import MultiFactorAuthClaim + mfa_info = await update_and_get_mfa_related_info_in_session( + MultiFactorAuthClaim, input_session=session, user_context=user_context, ) factors_set_up_for_user = await get_factors_setup_for_user( session.get_user_id(), user_context ) + from ..recipe import MultiFactorAuthRecipe recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() @@ -66,7 +69,10 @@ async def get_mfa_requirements_for_auth( if user_context is None: user_context = {} + from ..multi_factor_auth_claim import MultiFactorAuthClaim + mfa_info = await update_and_get_mfa_related_info_in_session( + MultiFactorAuthClaim, input_session=session, user_context=user_context, ) @@ -81,6 +87,7 @@ async def mark_factor_as_complete_in_session( ) -> None: if user_context is None: user_context = {} + from ..recipe import MultiFactorAuthRecipe recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() await recipe.recipe_implementation.mark_factor_as_complete_in_session( @@ -100,6 +107,7 @@ async def get_factors_setup_for_user( user = await get_user(user_id, user_context) if user is None: raise Exception("Unknown user id") + from ..recipe import MultiFactorAuthRecipe recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() return await recipe.recipe_implementation.get_factors_setup_for_user( @@ -114,6 +122,7 @@ async def get_required_secondary_factors_for_user( ) -> List[str]: if user_context is None: user_context = {} + from ..recipe import MultiFactorAuthRecipe recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() return await recipe.recipe_implementation.get_required_secondary_factors_for_user( @@ -129,6 +138,7 @@ async def add_to_required_secondary_factors_for_user( ) -> None: if user_context is None: user_context = {} + from ..recipe import MultiFactorAuthRecipe recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() await recipe.recipe_implementation.add_to_required_secondary_factors_for_user( @@ -145,6 +155,7 @@ async def remove_from_required_secondary_factors_for_user( ) -> None: if user_context is None: user_context = {} + from ..recipe import MultiFactorAuthRecipe recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() await recipe.recipe_implementation.remove_from_required_secondary_factors_for_user( @@ -152,6 +163,3 @@ async def remove_from_required_secondary_factors_for_user( factor_id=factor_id, user_context=user_context, ) - - -init = MultiFactorAuthRecipe.init diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py index 9f237ca98..f960e9ffe 100644 --- a/supertokens_python/recipe/multifactorauth/interfaces.py +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -123,15 +123,15 @@ async def resync_session_and_fetch_mfa_info_put( class NextFactors: def __init__( - self, next: List[str], already_setup: List[str], allowed_to_setup: List[str] + self, next_: List[str], already_setup: List[str], allowed_to_setup: List[str] ): - self.next = next + self.next_ = next_ self.already_setup = already_setup self.allowed_to_setup = allowed_to_setup def to_json(self) -> Dict[str, Any]: return { - "next": self.next, + "next": self.next_, "alreadySetup": self.already_setup, "allowedToSetup": self.allowed_to_setup, } diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py index 4545f3c63..4b9eeef28 100644 --- a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + from __future__ import annotations from typing import Any, Dict, Optional, Set @@ -15,7 +29,6 @@ MFAClaimValue, MFARequirementList, ) -from .utils import update_and_get_mfa_related_info_in_session class HasCompletedRequirementListSCV(SessionClaimValidator): @@ -29,14 +42,10 @@ def __init__( self.claim: MultiFactorAuthClaimClass = claim self.requirement_list = requirement_list - async def should_refetch( + def should_refetch( self, payload: Dict[str, Any], user_context: Dict[str, Any] ) -> bool: - return ( - True - if self.claim.key not in payload or not payload[self.claim.key] - else False - ) + return bool(self.claim.key not in payload or not payload[self.claim.key]) async def validate( self, payload: JSONObject, user_context: Dict[str, Any] @@ -65,7 +74,7 @@ async def validate( factor_ids = next_set_of_unsatisfied_factors.factor_ids - if next_set_of_unsatisfied_factors.type == "string": + if next_set_of_unsatisfied_factors.type_ == "string": return ClaimValidationResult( is_valid=False, reason={ @@ -74,7 +83,7 @@ async def validate( }, ) - elif next_set_of_unsatisfied_factors.type == "oneOf": + elif next_set_of_unsatisfied_factors.type_ == "oneOf": return ClaimValidationResult( is_valid=False, reason={ @@ -101,15 +110,11 @@ def __init__( super().__init__(id_) self.claim = claim - async def should_refetch( + def should_refetch( self, payload: Dict[str, Any], user_context: Dict[str, Any] ) -> bool: assert self.claim is not None - return ( - True - if self.claim.key not in payload or not payload[self.claim.key] - else False - ) + return bool(self.claim.key not in payload or not payload[self.claim.key]) async def validate( self, payload: JSONObject, user_context: Dict[str, Any] @@ -161,13 +166,16 @@ def __init__(self, key: Optional[str] = None): key = key or "st-mfa" async def fetch_value( - user_id: str, + _user_id: str, recipe_user_id: RecipeUserId, tenant_id: str, current_payload: Dict[str, Any], user_context: Dict[str, Any], ) -> MFAClaimValue: + from .utils import update_and_get_mfa_related_info_in_session + mfa_info = await update_and_get_mfa_related_info_in_session( + self, input_session_recipe_user_id=recipe_user_id, input_tenant_id=tenant_id, input_access_token_payload=current_payload, @@ -209,9 +217,11 @@ def get_next_set_of_unsatisfied_factors( ) if len(next_factors) > 0: - return FactorIdsAndType(factor_ids=list(next_factors), type=factor_type) + return FactorIdsAndType( + factor_ids=list(next_factors), type_=factor_type + ) - return FactorIdsAndType(factor_ids=[], type="string") + return FactorIdsAndType(factor_ids=[], type_="string") def add_to_payload_( self, diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index f45db7316..f60c30c34 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -47,9 +47,6 @@ GetEmailsForFactorOkResult, GetPhoneNumbersForFactorsOkResult, ) -from .utils import validate_and_normalise_user_input -from .recipe_implementation import RecipeImplementation -from .api.implementation import APIImplementation from .interfaces import APIOptions @@ -79,10 +76,13 @@ def __init__( ] = [] self.is_get_mfa_requirements_for_auth_overridden: bool = False + from .utils import validate_and_normalise_user_input + self.config = validate_and_normalise_user_input( first_factors, override, ) + from .recipe_implementation import RecipeImplementation recipe_implementation = RecipeImplementation( Querier.get_instance(recipe_id), self @@ -92,6 +92,7 @@ def __init__( if self.config.override.functions is None else self.config.override.functions(recipe_implementation) ) + from .api.implementation import APIImplementation api_implementation = APIImplementation() self.api_implementation = ( diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py index e2bc64251..9ae8f4a97 100644 --- a/supertokens_python/recipe/multifactorauth/recipe_implementation.py +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -58,10 +58,10 @@ def __init__( self.factor_id = factor_id self.mfa_requirement_for_auth = mfa_requirement_for_auth - async def should_refetch( + def should_refetch( self, payload: Dict[str, Any], user_context: Dict[str, Any] ) -> bool: - return True if self.claim.get_value_from_payload(payload) is None else False + return self.claim.get_value_from_payload(payload) is None async def validate( self, payload: JSONObject, user_context: Dict[str, Any] @@ -174,6 +174,7 @@ async def mark_factor_as_complete_in_session( self, session: SessionContainer, factor_id: str, user_context: Dict[str, Any] ): await update_and_get_mfa_related_info_in_session( + MultiFactorAuthClaim, input_session=session, input_updated_factor_id=factor_id, user_context=user_context, diff --git a/supertokens_python/recipe/multifactorauth/syncio/__init__.py b/supertokens_python/recipe/multifactorauth/syncio/__init__.py index c12a6ef4a..6bd9bf9f2 100644 --- a/supertokens_python/recipe/multifactorauth/syncio/__init__.py +++ b/supertokens_python/recipe/multifactorauth/syncio/__init__.py @@ -22,7 +22,6 @@ from ..interfaces import ( MFARequirementList, ) -from ..recipe import MultiFactorAuthRecipe def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( @@ -125,6 +124,3 @@ def remove_from_required_secondary_factors_for_user( ) return sync(async_func(user_id, factor_id, user_context)) - - -init = MultiFactorAuthRecipe.init diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index d43477291..779c8ad77 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -85,10 +85,10 @@ class FactorIdsAndType: def __init__( self, factor_ids: List[str], - type: Union[Literal["string"], Literal["oneOf"], Literal["allOfInAnyOrder"]], + type_: Union[Literal["string"], Literal["oneOf"], Literal["allOfInAnyOrder"]], ): self.factor_ids = factor_ids - self.type = type + self.type_ = type_ class GetFactorsSetupForUserFromOtherRecipesFunc: diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index cd9dcbad4..d28a48175 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -12,19 +12,17 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +import importlib -from typing import TYPE_CHECKING, List, Optional, Union -from typing import Dict, Any, Union, List -from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.recipe.session.exceptions import UnauthorisedError -from supertokens_python.recipe.multitenancy.asyncio import get_tenant from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe -from supertokens_python.recipe.multifactorauth.types import FactorIds from supertokens_python.recipe.multifactorauth.types import ( MFAClaimValue, MFARequirementList, + FactorIds, ) from supertokens_python.types import RecipeUserId import math @@ -34,6 +32,12 @@ if TYPE_CHECKING: from .types import OverrideConfig, MultiFactorAuthConfig + from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaimClass, + ) + from supertokens_python.recipe.multitenancy.recipe import ( + MultitenancyRecipe as MTRecipeType, + ) def validate_and_normalise_user_input( @@ -43,10 +47,12 @@ def validate_and_normalise_user_input( if first_factors is not None and len(first_factors) == 0: raise ValueError("'first_factors' can be either None or a non-empty list") + from .types import OverrideConfig as OC, MultiFactorAuthConfig as MFAC + if override is None: - override = OverrideConfig() + override = OC() - return MultiFactorAuthConfig( + return MFAC( first_factors=first_factors, override=override, ) @@ -67,6 +73,7 @@ def __init__( async def update_and_get_mfa_related_info_in_session( + MultiFactorAuthClaim: MultiFactorAuthClaimClass, user_context: Dict[str, Any], input_session_recipe_user_id: Optional[RecipeUserId] = None, input_tenant_id: Optional[str] = None, @@ -74,9 +81,6 @@ async def update_and_get_mfa_related_info_in_session( input_session: Optional[SessionContainer] = None, input_updated_factor_id: Optional[str] = None, ) -> UpdateAndGetMFARelatedInfoInSessionResult: - from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( - MultiFactorAuthClaim, - ) from supertokens_python.recipe.multifactorauth.recipe import ( MultiFactorAuthRecipe as Recipe, ) @@ -199,7 +203,16 @@ async def user_getter(): async def get_required_secondary_factors_for_tenant( tenant_id: str, user_context: Dict[str, Any] ) -> List[str]: - tenant_info = await get_tenant(tenant_id, user_context) + + MultitenancyRecipe = importlib.import_module( + "supertokens_python.recipe.multitenancy.recipe" + ) + + mt_recipe: MTRecipeType = MultitenancyRecipe.get_instance() + + tenant_info = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) if tenant_info is None: raise UnauthorisedError("Tenant not found") return ( @@ -262,14 +275,20 @@ async def get_required_secondary_factors_for_tenant_helper() -> List[str]: async def is_valid_first_factor( tenant_id: str, factor_id: str, user_context: Dict[str, Any] ) -> Literal["OK", "INVALID_FIRST_FACTOR_ERROR", "TENANT_NOT_FOUND_ERROR"]: - tenant_info = await get_tenant(tenant_id=tenant_id, user_context=user_context) + + MultitenancyRecipe = importlib.import_module( + "supertokens_python.recipe.multitenancy.recipe" + ) + + mt_recipe: MTRecipeType = MultitenancyRecipe.get_instance() + tenant_info = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) if tenant_info is None: return "TENANT_NOT_FOUND_ERROR" tenant_config = tenant_info - mt_recipe = MultitenancyRecipe.get_instance() - first_factors_from_mfa = mt_recipe.static_first_factors log_debug_message( diff --git a/supertokens_python/recipe/multitenancy/asyncio/__init__.py b/supertokens_python/recipe/multitenancy/asyncio/__init__.py index 4f8020401..c9b1c6ce7 100644 --- a/supertokens_python/recipe/multitenancy/asyncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/asyncio/__init__.py @@ -17,6 +17,7 @@ from supertokens_python.types import RecipeUserId from ..interfaces import ( + AssociateUserToTenantNotAllowedError, TenantConfig, CreateOrUpdateTenantOkResult, DeleteTenantOkResult, @@ -31,7 +32,6 @@ DisassociateUserFromTenantOkResult, TenantConfigCreateOrUpdate, ) -from ..recipe import MultitenancyRecipe if TYPE_CHECKING: from ..interfaces import ProviderConfig @@ -44,6 +44,8 @@ async def create_or_update_tenant( ) -> CreateOrUpdateTenantOkResult: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.create_or_update_tenant( @@ -56,6 +58,8 @@ async def delete_tenant( ) -> DeleteTenantOkResult: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.delete_tenant(tenant_id, user_context) @@ -66,6 +70,8 @@ async def get_tenant( ) -> Optional[TenantConfig]: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.get_tenant(tenant_id, user_context) @@ -77,6 +83,8 @@ async def list_all_tenants( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.list_all_tenants(user_context) @@ -91,6 +99,8 @@ async def create_or_update_third_party_config( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.create_or_update_third_party_config( @@ -106,6 +116,8 @@ async def delete_third_party_config( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.delete_third_party_config( @@ -123,10 +135,13 @@ async def associate_user_to_tenant( AssociateUserToTenantEmailAlreadyExistsError, AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantNotAllowedError, ]: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.associate_user_to_tenant( @@ -134,7 +149,7 @@ async def associate_user_to_tenant( ) -async def dissociate_user_from_tenant( +async def disassociate_user_from_tenant( tenant_id: str, recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, @@ -142,8 +157,10 @@ async def dissociate_user_from_tenant( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() - return await recipe.recipe_implementation.dissociate_user_from_tenant( + return await recipe.recipe_implementation.disassociate_user_from_tenant( tenant_id, recipe_user_id, user_context ) diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 4fc5ef551..52e66b3f7 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -136,6 +136,14 @@ class AssociateUserToTenantThirdPartyUserAlreadyExistsError: status = "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR" +class AssociateUserToTenantNotAllowedError: + status = "ASSOCIATION_NOT_ALLOWED_ERROR" + + def __init__(self, reason: str): + self.status = "ASSOCIATION_NOT_ALLOWED_ERROR" + self.reason = reason + + class DisassociateUserFromTenantOkResult: status = "OK" @@ -213,11 +221,12 @@ async def associate_user_to_tenant( AssociateUserToTenantEmailAlreadyExistsError, AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantNotAllowedError, ]: pass @abstractmethod - async def dissociate_user_from_tenant( + async def disassociate_user_from_tenant( self, tenant_id: str, recipe_user_id: RecipeUserId, diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 432ffc07d..c60f148f8 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -25,6 +25,7 @@ from supertokens_python.types import RecipeUserId from .interfaces import ( + AssociateUserToTenantNotAllowedError, RecipeInterface, TenantConfig, CreateOrUpdateTenantOkResult, @@ -254,6 +255,7 @@ async def associate_user_to_tenant( AssociateUserToTenantEmailAlreadyExistsError, AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantNotAllowedError, ]: response = await self.querier.send_post_request( NormalisedURLPath( @@ -283,10 +285,12 @@ async def associate_user_to_tenant( == AssociateUserToTenantThirdPartyUserAlreadyExistsError.status ): return AssociateUserToTenantThirdPartyUserAlreadyExistsError() + if response["status"] == AssociateUserToTenantNotAllowedError.status: + return AssociateUserToTenantNotAllowedError(response["reason"]) raise Exception("Should never come here") - async def dissociate_user_from_tenant( + async def disassociate_user_from_tenant( self, tenant_id: Optional[str], recipe_user_id: RecipeUserId, diff --git a/supertokens_python/recipe/multitenancy/syncio/__init__.py b/supertokens_python/recipe/multitenancy/syncio/__init__.py index 7384ee6bd..5448f2612 100644 --- a/supertokens_python/recipe/multitenancy/syncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/syncio/__init__.py @@ -108,7 +108,7 @@ def associate_user_to_tenant( return sync(associate_user_to_tenant(tenant_id, recipe_user_id, user_context)) -def dissociate_user_from_tenant( +def disassociate_user_from_tenant( tenant_id: str, recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, @@ -117,7 +117,7 @@ def dissociate_user_from_tenant( user_context = {} from supertokens_python.recipe.multitenancy.asyncio import ( - dissociate_user_from_tenant, + disassociate_user_from_tenant, ) - return sync(dissociate_user_from_tenant(tenant_id, recipe_user_id, user_context)) + return sync(disassociate_user_from_tenant(tenant_id, recipe_user_id, user_context)) diff --git a/supertokens_python/recipe/session/exceptions.py b/supertokens_python/recipe/session/exceptions.py index 637cf486c..d9eedddc9 100644 --- a/supertokens_python/recipe/session/exceptions.py +++ b/supertokens_python/recipe/session/exceptions.py @@ -87,12 +87,15 @@ def __init__(self, msg: str, payload: List[ClaimValidationError]): class ClaimValidationError: - def __init__(self, id_: str, reason: Optional[Dict[str, Any]]): - self.id = id_ - self.reason = reason + id_: str + reason: Optional[Union[str, Dict[str, Any]]] + + def __init__(self, id_: str, reason: Optional[Union[str, Dict[str, Any]]]): + self.id_: str = id_ + self.reason: Optional[Union[str, Dict[str, Any]]] = reason def to_json(self): - result: Dict[str, Any] = {"id": self.id} + result: Dict[str, Any] = {"id": self.id_} if self.reason is not None: result["reason"] = self.reason diff --git a/supertokens_python/recipe/thirdparty/asyncio/__init__.py b/supertokens_python/recipe/thirdparty/asyncio/__init__.py index e4735ac76..b374e041b 100644 --- a/supertokens_python/recipe/thirdparty/asyncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/asyncio/__init__.py @@ -30,7 +30,7 @@ async def manually_create_or_update_user( third_party_user_id: str, email: str, is_verified: bool, - session: Optional[SessionContainer], + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ ManuallyCreateOrUpdateUserOkResult, diff --git a/supertokens_python/recipe/thirdparty/syncio/__init__.py b/supertokens_python/recipe/thirdparty/syncio/__init__.py index 429b4b027..4481c8131 100644 --- a/supertokens_python/recipe/thirdparty/syncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/syncio/__init__.py @@ -29,7 +29,7 @@ def manually_create_or_update_user( third_party_user_id: str, email: str, is_verified: bool, - session: Optional[SessionContainer], + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ ManuallyCreateOrUpdateUserOkResult, diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py index 3cb4d484e..64c45f783 100644 --- a/supertokens_python/recipe/totp/interfaces.py +++ b/supertokens_python/recipe/totp/interfaces.py @@ -13,11 +13,26 @@ # under the License. from __future__ import annotations -from typing import Dict, Any, Union, TYPE_CHECKING +from typing import Dict, Any, Union, TYPE_CHECKING, Optional from abc import ABC, abstractmethod if TYPE_CHECKING: - from .types import * + from .types import ( + UserIdentifierInfoOkResult, + UnknownUserIdError, + UserIdentifierInfoDoesNotExistError, + CreateDeviceOkResult, + DeviceAlreadyExistsError, + UpdateDeviceOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + VerifyTOTPOkResult, + InvalidTOTPError, + LimitReachedError, + UnknownDeviceError, + ListDevicesOkResult, + TOTPNormalisedConfig, + ) from supertokens_python.recipe.session import SessionContainer from supertokens_python import AppInfo from supertokens_python.framework import BaseRequest, BaseResponse diff --git a/supertokens_python/recipe/totp/syncio/__init__.py b/supertokens_python/recipe/totp/syncio/__init__.py index 2eaecf37f..24eb4d2ed 100644 --- a/supertokens_python/recipe/totp/syncio/__init__.py +++ b/supertokens_python/recipe/totp/syncio/__init__.py @@ -18,7 +18,6 @@ from supertokens_python.async_to_sync_wrapper import sync -from ..recipe import TOTPRecipe from supertokens_python.recipe.totp.types import ( CreateDeviceOkResult, DeviceAlreadyExistsError, @@ -124,6 +123,3 @@ def verify_totp( from supertokens_python.recipe.totp.asyncio import verify_totp as async_func return sync(async_func(tenant_id, user_id, totp, user_context)) - - -init = TOTPRecipe.init diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py index fae156fe8..f9599bf65 100644 --- a/supertokens_python/recipe/totp/types.py +++ b/supertokens_python/recipe/totp/types.py @@ -75,9 +75,6 @@ def to_json(self) -> Dict[str, Any]: class UpdateDeviceOkResult(OkResult): - def __init__(self): - super().__init__() - def to_json(self) -> Dict[str, Any]: raise NotImplementedError() @@ -174,11 +171,6 @@ def to_json(self) -> Dict[str, Any]: class VerifyTOTPOkResult(OkResult): - def __init__( - self, - ): - super().__init__() - def to_json(self) -> Dict[str, Any]: return {"status": self.status} From 0cfec6f26a73c71e712f10d4c6584bb5eabc6ac4 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 16:13:13 +0530 Subject: [PATCH 065/126] fixes all cyclic imports --- supertokens_python/auth_utils.py | 3 ++- .../recipe/multifactorauth/api/implementation.py | 9 +++++---- .../recipe/multifactorauth/multi_factor_auth_claim.py | 7 +++++-- supertokens_python/recipe/multifactorauth/recipe.py | 10 +++++++--- .../recipe/multifactorauth/recipe_implementation.py | 8 ++++++-- .../recipe/multitenancy/api/implementation.py | 9 +++++++-- supertokens_python/supertokens.py | 7 +++++-- 7 files changed, 37 insertions(+), 16 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index ff3315ba3..08d77cb4d 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -912,12 +912,13 @@ async def filter_out_invalid_second_factors_or_throw_if_all_are_invalid( factors_set_up_for_user_prom: Optional[List[str]] = None mfa_info_prom = None - async def get_factors_set_up_for_user(): + async def get_factors_set_up_for_user() -> List[str]: nonlocal factors_set_up_for_user_prom if factors_set_up_for_user_prom is None: factors_set_up_for_user_prom = await mfa_instance.recipe_implementation.get_factors_setup_for_user( user=session_user, user_context=user_context ) + assert factors_set_up_for_user_prom is not None return factors_set_up_for_user_prom async def get_mfa_requirements_for_auth(): diff --git a/supertokens_python/recipe/multifactorauth/api/implementation.py b/supertokens_python/recipe/multifactorauth/api/implementation.py index d6fe15e98..dea5fe91f 100644 --- a/supertokens_python/recipe/multifactorauth/api/implementation.py +++ b/supertokens_python/recipe/multifactorauth/api/implementation.py @@ -17,9 +17,6 @@ from typing import Any, Dict, List, Union, TYPE_CHECKING from supertokens_python.recipe.session import SessionContainer -from supertokens_python.recipe.multifactorauth.utils import ( - update_and_get_mfa_related_info_in_session, -) from supertokens_python.recipe.multitenancy.asyncio import get_tenant from supertokens_python.asyncio import get_user from supertokens_python.recipe.session.exceptions import ( @@ -54,6 +51,10 @@ async def resync_session_and_fetch_mfa_info_put( MultiFactorAuthClaim: MultiFactorAuthClaimType = mfa.MultiFactorAuthClaim + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) + session_user = await get_user(session.get_user_id(), user_context) if session_user is None: @@ -61,7 +62,7 @@ async def resync_session_and_fetch_mfa_info_put( "Session user not found", ) - mfa_info = await update_and_get_mfa_related_info_in_session( + mfa_info = await module.update_and_get_mfa_related_info_in_session( MultiFactorAuthClaim, input_session=session, user_context=user_context, diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py index 4b9eeef28..2780bd5f4 100644 --- a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -13,6 +13,7 @@ # under the License. from __future__ import annotations +import importlib from typing import Any, Dict, Optional, Set @@ -172,9 +173,11 @@ async def fetch_value( current_payload: Dict[str, Any], user_context: Dict[str, Any], ) -> MFAClaimValue: - from .utils import update_and_get_mfa_related_info_in_session + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) - mfa_info = await update_and_get_mfa_related_info_in_session( + mfa_info = await module.update_and_get_mfa_related_info_in_session( self, input_session_recipe_user_id=recipe_user_id, input_tenant_id=tenant_id, diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py index f60c30c34..139826761 100644 --- a/supertokens_python/recipe/multifactorauth/recipe.py +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +import importlib from os import environ from typing import Any, Dict, Optional, List, Union @@ -31,7 +32,6 @@ MultiFactorAuthClaim, ) from supertokens_python.recipe.multitenancy.interfaces import TenantConfig -from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo @@ -76,9 +76,11 @@ def __init__( ] = [] self.is_get_mfa_requirements_for_auth_overridden: bool = False - from .utils import validate_and_normalise_user_input + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) - self.config = validate_and_normalise_user_input( + self.config = module.validate_and_normalise_user_input( first_factors, override, ) @@ -102,6 +104,8 @@ def __init__( ) def callback(): + from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + mt_recipe = MultitenancyRecipe.get_instance() mt_recipe.static_first_factors = self.config.first_factors diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py index 9ae8f4a97..f2cd77d46 100644 --- a/supertokens_python/recipe/multifactorauth/recipe_implementation.py +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +import importlib from typing import TYPE_CHECKING, Any, Awaitable, Dict, Set, Callable, List @@ -35,7 +36,6 @@ from supertokens_python.recipe.session import SessionContainer from supertokens_python.types import User -from .utils import update_and_get_mfa_related_info_in_session from .interfaces import RecipeInterface if TYPE_CHECKING: @@ -173,7 +173,11 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( async def mark_factor_as_complete_in_session( self, session: SessionContainer, factor_id: str, user_context: Dict[str, Any] ): - await update_and_get_mfa_related_info_in_session( + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) + + await module.update_and_get_mfa_related_info_in_session( MultiFactorAuthClaim, input_session=session, input_updated_factor_id=factor_id, diff --git a/supertokens_python/recipe/multitenancy/api/implementation.py b/supertokens_python/recipe/multitenancy/api/implementation.py index 27a58624f..565773e76 100644 --- a/supertokens_python/recipe/multitenancy/api/implementation.py +++ b/supertokens_python/recipe/multitenancy/api/implementation.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import importlib from typing import Any, Dict, Optional, Union, List from ..constants import DEFAULT_TENANT_ID @@ -35,7 +36,9 @@ async def login_methods_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[LoginMethodsGetOkResult, GeneralErrorResponse]: - from ...multifactorauth.utils import is_valid_first_factor + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) from supertokens_python.recipe.thirdparty.providers.config_utils import ( merge_providers_from_core_and_static, find_and_create_provider_instance, @@ -91,7 +94,9 @@ async def login_methods_get( valid_first_factors: List[str] = [] for factor_id in first_factors: - valid_res = await is_valid_first_factor(tenant_id, factor_id, user_context) + valid_res = await module.is_valid_first_factor( + tenant_id, factor_id, user_context + ) if valid_res == "OK": valid_first_factors.append(factor_id) if valid_res == "TENANT_NOT_FOUND_ERROR": diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 6aae8aa9b..51ec74b81 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -13,6 +13,7 @@ # under the License. from __future__ import annotations +import importlib from os import environ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, Tuple @@ -276,9 +277,11 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) if not multitenancy_found: - from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + module = importlib.import_module( + "supertokens_python.recipe.multitenancy.recipe" + ) - self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + self.recipe_modules.append(module.init()(self.app_info)) if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") if not user_metadata_found: From 9e78f9993b6f895a83799c9b13f82f93ca2bc15a Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 16:24:06 +0530 Subject: [PATCH 066/126] small changes --- .../multifactorauth/api/implementation.py | 14 +++------- .../recipe/multifactorauth/utils.py | 28 ++++--------------- supertokens_python/supertokens.py | 7 ++--- 3 files changed, 11 insertions(+), 38 deletions(-) diff --git a/supertokens_python/recipe/multifactorauth/api/implementation.py b/supertokens_python/recipe/multifactorauth/api/implementation.py index dea5fe91f..e16f422b8 100644 --- a/supertokens_python/recipe/multifactorauth/api/implementation.py +++ b/supertokens_python/recipe/multifactorauth/api/implementation.py @@ -14,7 +14,10 @@ from __future__ import annotations import importlib -from typing import Any, Dict, List, Union, TYPE_CHECKING +from typing import Any, Dict, List, Union +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.multitenancy.asyncio import get_tenant @@ -33,11 +36,6 @@ ResyncSessionAndFetchMFAInfoPUTOkResult, ) -if TYPE_CHECKING: - from ..multi_factor_auth_claim import ( - MultiFactorAuthClaimClass as MultiFactorAuthClaimType, - ) - class APIImplementation(APIInterface): async def resync_session_and_fetch_mfa_info_put( @@ -47,10 +45,6 @@ async def resync_session_and_fetch_mfa_info_put( user_context: Dict[str, Any], ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: - mfa = importlib.import_module("supertokens_python.recipe.multifactorauth") - - MultiFactorAuthClaim: MultiFactorAuthClaimType = mfa.MultiFactorAuthClaim - module = importlib.import_module( "supertokens_python.recipe.multifactorauth.utils" ) diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index d28a48175..c4d54cd50 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -12,9 +12,9 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations -import importlib - from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.recipe.session.exceptions import UnauthorisedError @@ -35,9 +35,6 @@ from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( MultiFactorAuthClaimClass, ) - from supertokens_python.recipe.multitenancy.recipe import ( - MultitenancyRecipe as MTRecipeType, - ) def validate_and_normalise_user_input( @@ -203,16 +200,7 @@ async def user_getter(): async def get_required_secondary_factors_for_tenant( tenant_id: str, user_context: Dict[str, Any] ) -> List[str]: - - MultitenancyRecipe = importlib.import_module( - "supertokens_python.recipe.multitenancy.recipe" - ) - - mt_recipe: MTRecipeType = MultitenancyRecipe.get_instance() - - tenant_info = await mt_recipe.recipe_implementation.get_tenant( - tenant_id=tenant_id, user_context=user_context - ) + tenant_info = await get_tenant(tenant_id, user_context) if tenant_info is None: raise UnauthorisedError("Tenant not found") return ( @@ -276,14 +264,8 @@ async def is_valid_first_factor( tenant_id: str, factor_id: str, user_context: Dict[str, Any] ) -> Literal["OK", "INVALID_FIRST_FACTOR_ERROR", "TENANT_NOT_FOUND_ERROR"]: - MultitenancyRecipe = importlib.import_module( - "supertokens_python.recipe.multitenancy.recipe" - ) - - mt_recipe: MTRecipeType = MultitenancyRecipe.get_instance() - tenant_info = await mt_recipe.recipe_implementation.get_tenant( - tenant_id=tenant_id, user_context=user_context - ) + mt_recipe = MultitenancyRecipe.get_instance() + tenant_info = await get_tenant(tenant_id=tenant_id, user_context=user_context) if tenant_info is None: return "TENANT_NOT_FOUND_ERROR" diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index 51ec74b81..6aae8aa9b 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -13,7 +13,6 @@ # under the License. from __future__ import annotations -import importlib from os import environ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, Tuple @@ -277,11 +276,9 @@ def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) if not multitenancy_found: - module = importlib.import_module( - "supertokens_python.recipe.multitenancy.recipe" - ) + from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe - self.recipe_modules.append(module.init()(self.app_info)) + self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) if totp_found and not multi_factor_auth_found: raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") if not user_metadata_found: From 5d0c4d395eefdc91ebee315174a76e03cf6f75bf Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 16:29:23 +0530 Subject: [PATCH 067/126] fixes stuff --- supertokens_python/auth_utils.py | 4 ---- .../recipe/multifactorauth/api/implementation.py | 1 - .../recipe/multifactorauth/asyncio/__init__.py | 6 ------ .../recipe/multifactorauth/multi_factor_auth_claim.py | 1 - .../recipe/multifactorauth/recipe_implementation.py | 1 - supertokens_python/recipe/multifactorauth/utils.py | 7 +++---- 6 files changed, 3 insertions(+), 17 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 08d77cb4d..7aa7a9709 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -924,12 +924,8 @@ async def get_factors_set_up_for_user() -> List[str]: async def get_mfa_requirements_for_auth(): nonlocal mfa_info_prom if mfa_info_prom is None: - from .recipe.multifactorauth.multi_factor_auth_claim import ( - MultiFactorAuthClaim, - ) mfa_info_prom = await update_and_get_mfa_related_info_in_session( - MultiFactorAuthClaim, input_session=session, user_context=user_context, ) diff --git a/supertokens_python/recipe/multifactorauth/api/implementation.py b/supertokens_python/recipe/multifactorauth/api/implementation.py index e16f422b8..edb888256 100644 --- a/supertokens_python/recipe/multifactorauth/api/implementation.py +++ b/supertokens_python/recipe/multifactorauth/api/implementation.py @@ -57,7 +57,6 @@ async def resync_session_and_fetch_mfa_info_put( ) mfa_info = await module.update_and_get_mfa_related_info_in_session( - MultiFactorAuthClaim, input_session=session, user_context=user_context, ) diff --git a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py index 4c8af2e54..8f51ced5b 100644 --- a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py +++ b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py @@ -33,10 +33,7 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( if user_context is None: user_context = {} - from ..multi_factor_auth_claim import MultiFactorAuthClaim - mfa_info = await update_and_get_mfa_related_info_in_session( - MultiFactorAuthClaim, input_session=session, user_context=user_context, ) @@ -69,10 +66,7 @@ async def get_mfa_requirements_for_auth( if user_context is None: user_context = {} - from ..multi_factor_auth_claim import MultiFactorAuthClaim - mfa_info = await update_and_get_mfa_related_info_in_session( - MultiFactorAuthClaim, input_session=session, user_context=user_context, ) diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py index 2780bd5f4..4193697ac 100644 --- a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -178,7 +178,6 @@ async def fetch_value( ) mfa_info = await module.update_and_get_mfa_related_info_in_session( - self, input_session_recipe_user_id=recipe_user_id, input_tenant_id=tenant_id, input_access_token_payload=current_payload, diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py index f2cd77d46..fc5cd1c90 100644 --- a/supertokens_python/recipe/multifactorauth/recipe_implementation.py +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -178,7 +178,6 @@ async def mark_factor_as_complete_in_session( ) await module.update_and_get_mfa_related_info_in_session( - MultiFactorAuthClaim, input_session=session, input_updated_factor_id=factor_id, user_context=user_context, diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index c4d54cd50..7b32b92ac 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -13,6 +13,9 @@ # under the License. from __future__ import annotations from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) from supertokens_python.recipe.multitenancy.asyncio import get_tenant from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.session import SessionContainer @@ -32,9 +35,6 @@ if TYPE_CHECKING: from .types import OverrideConfig, MultiFactorAuthConfig - from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( - MultiFactorAuthClaimClass, - ) def validate_and_normalise_user_input( @@ -70,7 +70,6 @@ def __init__( async def update_and_get_mfa_related_info_in_session( - MultiFactorAuthClaim: MultiFactorAuthClaimClass, user_context: Dict[str, Any], input_session_recipe_user_id: Optional[RecipeUserId] = None, input_tenant_id: Optional[str] = None, From 91dcf5728473ee5837e3e9da8cf7750ced039ebe Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 16:36:22 +0530 Subject: [PATCH 068/126] adds comments --- supertokens_python/recipe/multifactorauth/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py index 7b32b92ac..12f9ea0f6 100644 --- a/supertokens_python/recipe/multifactorauth/utils.py +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -37,6 +37,8 @@ from .types import OverrideConfig, MultiFactorAuthConfig +# IMPORTANT: If this function signature is modified, please update all tha places where this function is called. +# There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. def validate_and_normalise_user_input( first_factors: Optional[List[str]], override: Union[OverrideConfig, None] = None, @@ -69,6 +71,8 @@ def __init__( ) +# IMPORTANT: If this function signature is modified, please update all tha places where this function is called. +# There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. async def update_and_get_mfa_related_info_in_session( user_context: Dict[str, Any], input_session_recipe_user_id: Optional[RecipeUserId] = None, @@ -259,6 +263,8 @@ async def get_required_secondary_factors_for_tenant_helper() -> List[str]: ) +# IMPORTANT: If this function signature is modified, please update all tha places where this function is called. +# There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. async def is_valid_first_factor( tenant_id: str, factor_id: str, user_context: Dict[str, Any] ) -> Literal["OK", "INVALID_FIRST_FACTOR_ERROR", "TENANT_NOT_FOUND_ERROR"]: From fa03181d5cce581e9d17e6aed2e026ca107dfb1f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 19:52:48 +0530 Subject: [PATCH 069/126] fixes a few bugs --- .../recipe/accountlinking/utils.py | 2 +- .../multi_factor_auth_claim.py | 8 +++- tests/auth-react/flask-server/app.py | 40 +++++++++++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py index 5f0209fd1..6aa12765a 100644 --- a/supertokens_python/recipe/accountlinking/utils.py +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -52,7 +52,7 @@ async def default_should_do_automatic_account_linking( def recipe_init_defined_should_do_automatic_account_linking() -> bool: - return _did_use_default_should_do_automatic_account_linking + return not _did_use_default_should_do_automatic_account_linking def validate_and_normalise_user_input( diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py index 4193697ac..3c4bd6551 100644 --- a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -59,7 +59,9 @@ async def validate( "This should never happen, claim value not present in payload" ) - claim_val: MFAClaimValue = payload[self.claim.key] + claim_val: MFAClaimValue = MFAClaimValue( + c=payload[self.claim.key]["c"], v=payload[self.claim.key]["v"] + ) completed_factors = claim_val.c next_set_of_unsatisfied_factors = ( @@ -125,7 +127,9 @@ async def validate( raise Exception( "This should never happen, claim value not present in payload" ) - claim_val: MFAClaimValue = payload[self.claim.key] + claim_val: MFAClaimValue = MFAClaimValue( + c=payload[self.claim.key]["c"], v=payload[self.claim.key]["v"] + ) return ClaimValidationResult( is_valid=claim_val.v, diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 55ed847c8..fabef9b3f 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -321,6 +321,30 @@ async def get_user_info( # pylint: disable=no-self-use return oi +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + def custom_init(): global contact_method global flow_type @@ -713,6 +737,21 @@ async def resend_code_post( ), override=auth0_provider_override, ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), ] global enabled_providers @@ -1386,6 +1425,7 @@ def index(_: str): @app.errorhandler(Exception) # type: ignore def all_exception_handler(e: Exception): + print(e) return "Error", 500 From 12a624cecd1313f58e2345548573ede076ecb724 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 21:22:29 +0530 Subject: [PATCH 070/126] fixes bugs --- supertokens_python/auth_utils.py | 8 +++--- .../recipe/emailpassword/api/signup.py | 25 +++++++++++-------- .../recipe/thirdparty/api/implementation.py | 2 +- supertokens_python/types.py | 7 +++--- tests/auth-react/flask-server/app.py | 2 ++ 5 files changed, 26 insertions(+), 18 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 7aa7a9709..7e1bcff7c 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -42,7 +42,7 @@ class LinkingToSessionUserFailedError: - status: Literal["LINKING_TO_SESSION_USER_FAILED"] + status: Literal["LINKING_TO_SESSION_USER_FAILED"] = "LINKING_TO_SESSION_USER_FAILED" reason: Literal[ "EMAIL_VERIFICATION_REQUIRED", "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", @@ -76,11 +76,11 @@ def __init__(self, valid_factor_ids: List[str], is_first_factor: bool): class SignUpNotAllowedResponse: - status: Literal["SIGN_UP_NOT_ALLOWED"] + status: Literal["SIGN_UP_NOT_ALLOWED"] = "SIGN_UP_NOT_ALLOWED" class SignInNotAllowedResponse: - status: Literal["SIGN_IN_NOT_ALLOWED"] + status: Literal["SIGN_IN_NOT_ALLOWED"] = "SIGN_IN_NOT_ALLOWED" async def pre_auth_checks( @@ -325,7 +325,7 @@ async def get_authenticating_user_and_add_to_current_tenant_if_required( for lm in user.login_methods if lm.recipe_id == recipe_id and ( - lm.has_same_email_as(email) + (email is not None and lm.has_same_email_as(email)) or lm.has_same_phone_number_as(phone_number) or lm.has_same_third_party_info_as(third_party) ) diff --git a/supertokens_python/recipe/emailpassword/api/signup.py b/supertokens_python/recipe/emailpassword/api/signup.py index 13c9aacd1..21b2ec9bc 100644 --- a/supertokens_python/recipe/emailpassword/api/signup.py +++ b/supertokens_python/recipe/emailpassword/api/signup.py @@ -16,7 +16,10 @@ from typing import TYPE_CHECKING, Any, Dict from supertokens_python.auth_utils import load_session_in_auth_api_if_needed -from supertokens_python.recipe.emailpassword.interfaces import SignUpPostOkResult +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + SignUpPostOkResult, +) from supertokens_python.types import GeneralErrorResponse from ..exceptions import raise_form_field_exception @@ -93,12 +96,14 @@ async def handle_sign_up_api( if isinstance(response, GeneralErrorResponse): return send_200_response(response.to_json(), api_options.response) - return raise_form_field_exception( - "EMAIL_ALREADY_EXISTS_ERROR", - [ - ErrorFormField( - id="email", - error="This email already exists. Please sign in instead.", - ) - ], - ) + if isinstance(response, EmailAlreadyExistsError): + return raise_form_field_exception( + "EMAIL_ALREADY_EXISTS_ERROR", + [ + ErrorFormField( + id="email", + error="This email already exists. Please sign in instead.", + ) + ], + ) + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index f5266191f..75713db13 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -201,7 +201,7 @@ async def check_credentials_on_tenant(_: str): if not isinstance(pre_auth_checks_result, OkResponse): if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): - reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] assert isinstance(reason, str) return SignInUpNotAllowed(reason) if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 936f6f191..939868e71 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -109,10 +109,11 @@ def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: def has_same_third_party_info_as( self, third_party: Union[ThirdPartyInfo, None] ) -> bool: - if third_party is None or self.third_party is None: + if third_party is None: return False return ( - self.third_party.id.strip() == third_party.id.strip() + self.third_party is not None + and self.third_party.id.strip() == third_party.id.strip() and self.third_party.user_id.strip() == third_party.user_id.strip() ) @@ -140,7 +141,7 @@ def from_json(json: Dict[str, Any]) -> "LoginMethod": phone_number=json["phoneNumber"] if "phoneNumber" in json else None, third_party=( ( - TPI(json["thirdParty"]["id"], json["thirdParty"]["userId"]) + TPI(json["thirdParty"]["userId"], json["thirdParty"]["id"]) if "thirdParty" in json else None ) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index fabef9b3f..ddb3b4b7a 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. import os +import traceback from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from dotenv import load_dotenv @@ -1426,6 +1427,7 @@ def index(_: str): @app.errorhandler(Exception) # type: ignore def all_exception_handler(e: Exception): print(e) + print(traceback.format_exc()) return "Error", 500 From d31a8fb115552e35fad137a260e16107e03ce201 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 30 Sep 2024 21:39:49 +0530 Subject: [PATCH 071/126] fixes bugs --- .../recipe/passwordless/api/implementation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 17aaaeba7..fb6b49db0 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -214,13 +214,11 @@ async def create_code_post( if not isinstance(pre_auth_checks_result, OkResponse): if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): - reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] assert isinstance(reason, str) return SignInUpPostNotAllowedResponse(reason) if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): - reason = error_code_map["SIGN_IN_NOT_ALLOWED"] - assert isinstance(reason, str) - return SignInUpPostNotAllowedResponse(reason) + raise Exception("Should never come here") reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] assert isinstance(reason_dict, Dict) From 37b41d42d58b8e48f65db5303c1114bf27f58c0c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 1 Oct 2024 16:52:40 +0530 Subject: [PATCH 072/126] fixes more tests --- dev-requirements.txt | 1 + supertokens_python/auth_utils.py | 12 ++++---- .../multi_factor_auth_claim.py | 5 +++- .../recipe/totp/recipe_implementation.py | 3 +- tests/auth-react/flask-server/app.py | 29 ++++++++++++++++++- 5 files changed, 41 insertions(+), 9 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 4291ad283..0694d481a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -85,3 +85,4 @@ uvicorn==0.18.2 Werkzeug==2.0.3 wrapt==1.13.3 zipp==3.7.0 +pyotp==2.9.0 \ No newline at end of file diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index 7e1bcff7c..e471ea522 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -452,9 +452,9 @@ class OkFirstFactorResponse: class OkSecondFactorLinkedResponse: - status: Literal["OK"] - is_first_factor: Literal[False] - input_user_already_linked_to_session_user: Literal[True] + status: Literal["OK"] = "OK" + is_first_factor: Literal[False] = False + input_user_already_linked_to_session_user: Literal[True] = True session_user: User def __init__(self, session_user: User): @@ -462,9 +462,9 @@ def __init__(self, session_user: User): class OkSecondFactorNotLinkedResponse: - status: Literal["OK"] - is_first_factor: Literal[False] - input_user_already_linked_to_session_user: Literal[False] + status: Literal["OK"] = "OK" + is_first_factor: Literal[False] = False + input_user_already_linked_to_session_user: Literal[False] = False session_user: User linking_to_session_user_requires_verification: bool diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py index 3c4bd6551..8b45108ba 100644 --- a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -259,7 +259,10 @@ def remove_from_payload_by_merge_( def get_value_from_payload( self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> Optional[MFAClaimValue]: - return payload.get(self.key) + value = payload.get(self.key) + if value is None: + return None + return MFAClaimValue(c=value["c"], v=value["v"]) MultiFactorAuthClaim = MultiFactorAuthClaimClass() diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py index b213c4eb2..ce213dd88 100644 --- a/supertokens_python/recipe/totp/recipe_implementation.py +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -107,10 +107,11 @@ async def create_device( data = { "userId": user_id, - "deviceName": device_name, "skew": skew if skew is not None else self.config.default_skew, "period": period if period is not None else self.config.default_period, } + if device_name is not None: + data["deviceName"] = device_name response = await self.querier.send_post_request( NormalisedURLPath("/recipe/totp/device"), data, diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index ddb3b4b7a..913183655 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -174,7 +174,7 @@ def get_website_domain(): os.environ.setdefault("SUPERTOKENS_ENV", "testing") -latest_url_with_token = None +latest_url_with_token = "" code_store: Dict[str, List[Dict[str, Any]]] = {} accountlinking_config: Dict[str, Any] = {} @@ -1286,6 +1286,8 @@ def before_each(): global enabled_providers global enabled_recipes global mfa_info + global latest_url_with_token + latest_url_with_token = "" code_store = dict() accountlinking_config = {} enabled_providers = None @@ -1319,6 +1321,16 @@ def test_set_account_linking_config(): return "", 200 +@app.route("/setMFAInfo", methods=["POST"]) # type: ignore +def set_mfa_info(): + global mfa_info + body = request.get_json() + if body is None: + return jsonify({"error": "Invalid request body"}), 400 + mfa_info = body + return jsonify({"status": "OK"}) + + @app.route("/test/setEnabledRecipes", methods=["POST"]) # type: ignore def test_set_enabled_recipes(): global enabled_recipes @@ -1332,6 +1344,21 @@ def test_set_enabled_recipes(): return "", 200 +@app.route("/test/getTOTPCode", methods=["POST"]) # type: ignore +def test_get_totp_code(): + from pyotp import TOTP + + body = request.get_json() + if body is None or "secret" not in body: + return jsonify({"error": "Invalid request body"}), 400 + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return jsonify({"totp": code}) + + @app.get("/test/getDevice") # type: ignore def test_get_device(): global code_store From cdcc4fcd4e05815b21e2b56c5633ffd09ee30f3f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 1 Oct 2024 16:54:14 +0530 Subject: [PATCH 073/126] fixes dependencies for cicd --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 41143697e..1b3b89397 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ "Fastapi", "uvicorn==0.18.2", "python-dotenv==0.19.2", + "pyotp==2.9.0", ] ), "flask": ( @@ -26,6 +27,7 @@ "flask_cors", "Flask", "python-dotenv==0.19.2", + "pyotp==2.9.0", ] ), "django": ( @@ -35,6 +37,7 @@ "django-stubs==1.9.0", "uvicorn==0.18.2", "python-dotenv==0.19.2", + "pyotp==2.9.0", ] ), "django2x": ( @@ -44,6 +47,7 @@ "django-stubs==1.9.0", "gunicorn==20.1.0", "python-dotenv==0.19.2", + "pyotp==2.9.0", ] ), "drf": ( @@ -57,6 +61,7 @@ "uvicorn==0.18.2", "python-dotenv==0.19.2", "tzdata==2021.5", + "pyotp==2.9.0", ] ), } From 61cfea3d1a10acce3d01df880adeb1c248d13372 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 1 Oct 2024 22:32:02 +0530 Subject: [PATCH 074/126] fixes more tests --- tests/auth-react/flask-server/app.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 913183655..0af55857c 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -828,8 +828,8 @@ async def get_factors_setup_for_user( user_context: Dict[str, Any], ): res = await og_get_factors_setup_for_user(user, user_context) - if mfa_info.get("alreadySetup"): - return mfa_info.get("alreadySetup", []) + if "alreadySetup" in mfa_info: + return mfa_info["alreadySetup"] return res og_assert_allowed_to_setup_factor = ( @@ -843,7 +843,7 @@ async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( factors_set_up_for_user: Callable[[], Awaitable[List[str]]], user_context: Dict[str, Any], ): - if mfa_info.get("allowedToSetup"): + if "allowedToSetup" in mfa_info: if factor_id not in mfa_info["allowedToSetup"]: raise InvalidClaimsError( msg="INVALID_CLAIMS", @@ -884,7 +884,7 @@ async def get_mfa_requirements_for_auth( required_secondary_factors_for_tenant, user_context, ) - if mfa_info.get("requirements"): + if "requirements" in mfa_info: return mfa_info["requirements"] return res @@ -914,10 +914,10 @@ async def resync_session_and_fetch_mfa_info_put( ) if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): - if mfa_info.get("alreadySetup"): + if "alreadySetup" in mfa_info: res.factors.already_setup = mfa_info["alreadySetup"][:] - if mfa_info.get("noContacts"): + if "noContacts" in mfa_info: res.emails = {} res.phone_numbers = {} From 520a30bd735ed7f8dff29781eb31eaa0d8b19aa0 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 1 Oct 2024 23:28:57 +0530 Subject: [PATCH 075/126] more fixes --- .../recipe/multifactorauth/syncio/__init__.py | 6 +----- tests/auth-react/flask-server/app.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/supertokens_python/recipe/multifactorauth/syncio/__init__.py b/supertokens_python/recipe/multifactorauth/syncio/__init__.py index 6bd9bf9f2..8268fd19a 100644 --- a/supertokens_python/recipe/multifactorauth/syncio/__init__.py +++ b/supertokens_python/recipe/multifactorauth/syncio/__init__.py @@ -19,10 +19,6 @@ from supertokens_python.recipe.session import SessionContainer from supertokens_python.async_to_sync_wrapper import sync -from ..interfaces import ( - MFARequirementList, -) - def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( session: SessionContainer, @@ -42,7 +38,7 @@ def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( def get_mfa_requirements_for_auth( session: SessionContainer, user_context: Optional[Dict[str, Any]] = None, -) -> MFARequirementList: +): if user_context is None: user_context = {} diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 0af55857c..a4310c4c1 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -57,6 +57,9 @@ ResyncSessionAndFetchMFAInfoPUTOkResult, ) from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.syncio import ( + add_to_required_secondary_factors_for_user, +) from supertokens_python.recipe.multifactorauth.types import MFARequirementList from supertokens_python.recipe.multitenancy.interfaces import ( AssociateUserToTenantEmailAlreadyExistsError, @@ -1331,6 +1334,20 @@ def set_mfa_info(): return jsonify({"status": "OK"}) +@app.route("/addRequiredFactor", methods=["POST"]) # type: ignore +@verify_session() +def add_required_factor(): + session_: SessionContainer = g.supertokens # type: ignore + + body = request.get_json() + if body is None or "factorId" not in body: + return jsonify({"error": "Invalid request body"}), 400 + + add_to_required_secondary_factors_for_user(session_.get_user_id(), body["factorId"]) + + return jsonify({"status": "OK"}) + + @app.route("/test/setEnabledRecipes", methods=["POST"]) # type: ignore def test_set_enabled_recipes(): global enabled_recipes From 8c6f3316fe135711703f4b80fb06e515b1cb3fb2 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 2 Oct 2024 13:45:24 +0530 Subject: [PATCH 076/126] fixes more issues --- .../recipe/thirdparty/provider.py | 36 +++++++++++-------- .../recipe/thirdparty/providers/custom.py | 8 ++--- tests/auth-react/flask-server/app.py | 20 +++++------ 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index 769343a33..44e426cd3 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -144,7 +144,9 @@ def to_json(self) -> Dict[str, Any]: return {k: v for k, v in res.items() if v is not None} @staticmethod - def from_json(json: Dict[str, Any]) -> UserFields: + def from_json(json: Optional[Dict[str, Any]]) -> Optional[UserFields]: + if json is None: + return None return UserFields( user_id=json.get("userId", None), email=json.get("email", None), @@ -170,12 +172,14 @@ def to_json(self) -> Dict[str, Any]: return res @staticmethod - def from_json(json: Dict[str, Any]) -> UserInfoMap: + def from_json(json: Optional[Dict[str, Any]]) -> Optional[UserInfoMap]: + if json is None: + return None return UserInfoMap( from_id_token_payload=UserFields.from_json( - json.get("fromIdTokenPayload", {}) + json.get("fromIdTokenPayload", None) ), - from_user_info_api=UserFields.from_json(json.get("fromUserInfoAPI", {})), + from_user_info_api=UserFields.from_json(json.get("fromUserInfoAPI", None)), ) @@ -394,22 +398,24 @@ def to_json(self) -> Dict[str, Any]: def from_json(json: Dict[str, Any]) -> ProviderConfig: return ProviderConfig( third_party_id=json.get("thirdPartyId", ""), - name=json.get("name", ""), + name=json.get("name", None), clients=[ ProviderClientConfig.from_json(c) for c in json.get("clients", []) ], - authorization_endpoint=json.get("authorizationEndpoint", ""), + authorization_endpoint=json.get("authorizationEndpoint", None), authorization_endpoint_query_params=json.get( - "authorizationEndpointQueryParams", {} + "authorizationEndpointQueryParams", None ), - token_endpoint=json.get("tokenEndpoint", ""), - token_endpoint_body_params=json.get("tokenEndpointBodyParams", {}), - user_info_endpoint=json.get("userInfoEndpoint", ""), - user_info_endpoint_query_params=json.get("userInfoEndpointQueryParams", {}), - user_info_endpoint_headers=json.get("userInfoEndpointHeaders", {}), - jwks_uri=json.get("jwksURI", ""), - oidc_discovery_endpoint=json.get("oidcDiscoveryEndpoint", ""), - user_info_map=UserInfoMap.from_json(json.get("userInfoMap", {})), + token_endpoint=json.get("tokenEndpoint", None), + token_endpoint_body_params=json.get("tokenEndpointBodyParams", None), + user_info_endpoint=json.get("userInfoEndpoint", None), + user_info_endpoint_query_params=json.get( + "userInfoEndpointQueryParams", None + ), + user_info_endpoint_headers=json.get("userInfoEndpointHeaders", None), + jwks_uri=json.get("jwksURI", None), + oidc_discovery_endpoint=json.get("oidcDiscoveryEndpoint", None), + user_info_map=UserInfoMap.from_json(json.get("userInfoMap", None)), require_email=json.get("requireEmail", None), validate_id_token_payload=None, generate_fake_email=None, diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index 9b6a355de..4d65ef894 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -437,11 +437,11 @@ async def get_user_info( def NewProvider( - input: ProviderInput, # pylint: disable=redefined-builtin + input_: ProviderInput, base_class: Callable[[ProviderConfig], Provider] = GenericProvider, ) -> Provider: - provider_instance = base_class(input.config) - if input.override is not None: - provider_instance = input.override(provider_instance) + provider_instance = base_class(input_.config) + if input_.override is not None: + provider_instance = input_.override(provider_instance) return provider_instance diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index a4310c4c1..f01694e30 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -81,7 +81,10 @@ ClaimValidationError, InvalidClaimsError, ) -from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.thirdparty.provider import ( + Provider, + RedirectUriInfo, +) from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions as EPAPIOptions, ) @@ -131,7 +134,10 @@ SessionClaimValidator, SessionContainer, ) -from supertokens_python.recipe.thirdparty import ProviderConfig, ThirdPartyRecipe +from supertokens_python.recipe.thirdparty import ( + ProviderConfig, + ThirdPartyRecipe, +) from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface as ThirdpartyAPIInterface, ManuallyCreateOrUpdateUserOkResult, @@ -1176,7 +1182,7 @@ def setup_tenant(): raise Exception("Should never come here") tenant_id = body["tenantId"] login_methods = body["loginMethods"] - core_config = body["coreConfig"] + core_config = "coreConfig" in body and body["coreConfig"] or {} first_factors: List[str] = [] if login_methods.get("emailPassword", {}).get("enabled") == True: @@ -1196,15 +1202,9 @@ def setup_tenant(): if login_methods.get("thirdParty", {}).get("providers") is not None: for provider in login_methods["thirdParty"]["providers"]: - if ( - len(provider) > 1 - ): # TODO: remove this once all tests pass, this is just for making sure we pass the right stuff into ProviderConfig - raise Exception("Pass more stuff into ProviderConfig:" + str(provider)) create_or_update_third_party_config( tenant_id, - config=ProviderConfig( - third_party_id=provider["id"], - ), + config=ProviderConfig.from_json(provider), ) return jsonify({"status": "OK", "createdNew": core_resp.created_new}) From 51b5a2752abfff18edff81862aa64c088a3a4a82 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 2 Oct 2024 15:24:49 +0530 Subject: [PATCH 077/126] fixes more tests --- tests/auth-react/flask-server/app.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index f01694e30..3ee60c66d 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -1290,6 +1290,10 @@ def before_each(): global enabled_recipes global mfa_info global latest_url_with_token + global contact_method + global flow_type + contact_method = "EMAIL_OR_PHONE" + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" latest_url_with_token = "" code_store = dict() accountlinking_config = {} From b0b2a5ce6784693bcc8811368e5450af24736ded Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 3 Oct 2024 15:41:25 +0530 Subject: [PATCH 078/126] fixes fastapi auth-react server --- .vscode/launch.json | 12 + tests/auth-react/fastapi-server/app.py | 776 +++++++++++++++++++++---- 2 files changed, 689 insertions(+), 99 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 9c01eaf88..58eed6c40 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -30,6 +30,18 @@ "FLASK_DEBUG": "1" }, "jinja": true + }, + { + "name": "Python: FastAPI, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/fastapi-server/app.py", + "args": [ + "--port", + "8083" + ], + "cwd": "${workspaceFolder}/tests/auth-react/fastapi-server", + "jinja": true } ] } \ No newline at end of file diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index 3ba30b4aa..9d6869626 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -13,7 +13,7 @@ # under the License. import os import typing -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union import uvicorn # type: ignore from dotenv import load_dotenv @@ -26,12 +26,14 @@ from starlette.responses import Response from starlette.types import ASGIApp from typing_extensions import Literal -from supertokens_python.recipe import multitenancy +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe import multifactorauth, multitenancy, totp from supertokens_python import ( InputAppInfo, Supertokens, SupertokensConfig, + convert_to_recipe_user_id, get_all_cors_headers, init, ) @@ -45,10 +47,18 @@ thirdparty, userroles, ) +from supertokens_python.recipe import accountlinking +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe +from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EmailPasswordAPIInterface, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, ) from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions as EPAPIOptions, @@ -72,18 +82,51 @@ APIOptions as EVAPIOptions, ) from supertokens_python.recipe.jwt import JWTRecipe +from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user, +) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe.multitenancy.asyncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, PasswordlessRecipe, ) +from supertokens_python.recipe.passwordless.asyncio import update_user from supertokens_python.recipe.passwordless.interfaces import ( APIInterface as PasswordlessAPIInterface, + PhoneNumberChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, ) from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) from supertokens_python.recipe.session.framework.fastapi import verify_session from supertokens_python.recipe.session.interfaces import ( APIInterface as SessionAPIInterface, @@ -91,13 +134,19 @@ from supertokens_python.recipe.session.interfaces import APIOptions as SAPIOptions from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.recipe.thirdparty import ( + ProviderConfig, ThirdPartyRecipe, ) +from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface as ThirdpartyAPIInterface, + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import ( PermissionClaim, @@ -108,8 +157,13 @@ add_role_to_user, create_new_role_or_add_permissions, ) -from supertokens_python.types import AccountInfo, GeneralErrorResponse -from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.types import ( + AccountInfo, + GeneralErrorResponse, + RecipeUserId, + User, +) +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.asyncio import delete_user load_dotenv() @@ -119,6 +173,14 @@ os.environ.setdefault("SUPERTOKENS_ENV", "testing") code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None class CustomPlessEmailService( @@ -145,9 +207,7 @@ async def send_email( code_store[template_vars.pre_auth_session_id] = codes -class CustomPlessSMSService( - passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars] -): +class CustomSMSService(passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars]): async def send_sms( self, template_vars: passwordless.SMSTemplateVars, user_context: Dict[str, Any] ) -> None: @@ -225,7 +285,7 @@ def get_website_domain(): return "http://localhost:" + get_website_port() -latest_url_with_token = None +latest_url_with_token = "" async def validate_age(value: Any, tenant_id: str): @@ -265,75 +325,48 @@ async def get_user_info( # pylint: disable=no-self-use return oi -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + global contact_method + global flow_type + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() EmailVerificationRecipe.reset() SessionRecipe.reset() ThirdPartyRecipe.reset() - EmailVerificationRecipe.reset() EmailPasswordRecipe.reset() + EmailVerificationRecipe.reset() DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() - - providers_list: List[thirdparty.ProviderInput] = [ - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="google", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["GOOGLE_CLIENT_ID"], - client_secret=os.environ["GOOGLE_CLIENT_SECRET"], - ), - ], - ), - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="github", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["GITHUB_CLIENT_ID"], - client_secret=os.environ["GITHUB_CLIENT_SECRET"], - ), - ], - ) - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="facebook", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["FACEBOOK_CLIENT_ID"], - client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], - ), - ], - ) - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="auth0", - name="Auth0", - authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", - authorization_endpoint_query_params={"scope": "openid profile"}, - token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["AUTH0_CLIENT_ID"], - client_secret=os.environ["AUTH0_CLIENT_SECRET"], - ) - ], - ), - override=auth0_provider_override, - ), - ] + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -358,7 +391,11 @@ async def email_verify_post( if is_general_error: return GeneralErrorResponse("general error from API email verify") return await original_email_verify_post( - token, session, tenant_id, api_options, user_context + token, + session, + tenant_id, + api_options, + user_context, ) async def generate_email_verify_token_post( @@ -374,9 +411,7 @@ async def generate_email_verify_token_post( "general error from API email verification code" ) return await original_generate_email_verify_token_post( - session, - api_options, - user_context, + session, api_options, user_context ) original_implementation_email_verification.email_verify_post = email_verify_post @@ -656,12 +691,87 @@ async def resend_code_post( original_implementation.resend_code_post = resend_code_post return original_implementation + providers_list: List[thirdparty.ProviderInput] = [ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="google", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GOOGLE_CLIENT_ID"], + client_secret=os.environ["GOOGLE_CLIENT_SECRET"], + ), + ], + ), + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="github", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GITHUB_CLIENT_ID"], + client_secret=os.environ["GITHUB_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="facebook", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["FACEBOOK_CLIENT_ID"], + client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="auth0", + name="Auth0", + authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", + authorization_endpoint_query_params={"scope": "openid profile"}, + token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["AUTH0_CLIENT_ID"], + client_secret=os.environ["AUTH0_CLIENT_SECRET"], + ) + ], + ), + override=auth0_provider_override, + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), + ] + + global enabled_providers + if enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in enabled_providers + ] + if contact_method is not None and flow_type is not None: if contact_method == "PHONE": passwordless_init = passwordless.init( contact_config=ContactPhoneOnlyConfig(), flow_type=flow_type, - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -684,7 +794,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -694,7 +804,7 @@ async def resend_code_post( contact_config=ContactEmailOrPhoneConfig(), flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), ) @@ -703,32 +813,243 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + global mfa_info + + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mfa_info: + return mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mfa_info: + if factor_id not in mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mfa_info: + return mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mfa_info: + res.factors.already_setup = mfa_info["alreadySetup"][:] + + if "noContacts" in mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailpassword.EmailDeliveryConfig(CustomEPEmailService()), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), + ), + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + + global accountlinking_config + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + global enabled_recipes + if enabled_recipes is not None: + recipe_list = [ + item["init"] for item in recipe_list if item["id"] in enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( @@ -757,20 +1078,272 @@ async def exception_handler(a, b): # type: ignore @app.post("/beforeeach") def before_each(): global code_store + global accountlinking_config + global enabled_providers + global enabled_recipes + global mfa_info + global latest_url_with_token + global contact_method + global flow_type + contact_method = "EMAIL_OR_PHONE" + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + latest_url_with_token = "" code_store = dict() + accountlinking_config = {} + enabled_providers = None + enabled_recipes = None + mfa_info = {} custom_init() return PlainTextResponse("") +@app.post("/changeEmail") +async def change_email(request: Request): + body: Union[dict[str, Any], None] = await request.json() + if body is None: + raise Exception("Should never come here") + + if body["rid"] == "emailpassword": + resp = await update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return JSONResponse({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JSONResponse(resp.to_json()) + elif body["rid"] == "thirdparty": + user = await get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = await manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return JSONResponse( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return JSONResponse( + {"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason} + ) + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = await update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return JSONResponse({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return JSONResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + if isinstance(resp, PhoneNumberChangeNotAllowedError): + return JSONResponse( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": resp.reason, + } + ) + + raise Exception("Should not come here") + + +@app.post("/setupTenant") +async def setup_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body.get("coreConfig", {}) + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = await create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + await create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return JSONResponse({"status": "OK", "createdNew": core_resp.created_new}) + + +@app.post("/addUserToTenant") +async def add_user_to_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return JSONResponse( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return JSONResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return JSONResponse({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return JSONResponse( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +@app.post("/removeUserFromTenant") +async def remove_user_from_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await disassociate_user_from_tenant( + tenant_id, RecipeUserId(recipe_user_id) + ) + + return JSONResponse({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +@app.post("/removeTenant") +async def remove_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = await delete_tenant(tenant_id) + + return JSONResponse({"status": "OK", "didExist": core_resp.did_exist}) + + @app.post("/test/setFlow") async def test_set_flow(request: Request): body = await request.json() + global contact_method + global flow_type contact_method = body["contactMethod"] flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + custom_init() return PlainTextResponse("") +@app.post("/test/setAccountLinkingConfig") +async def test_set_account_linking_config(request: Request): + global accountlinking_config + body = await request.json() + if body is None: + raise Exception("Invalid request body") + accountlinking_config = body + custom_init() + return PlainTextResponse("", status_code=200) + + +@app.post("/setMFAInfo") +async def set_mfa_info(request: Request): + global mfa_info + body = await request.json() + if body is None: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + mfa_info = body + return JSONResponse({"status": "OK"}) + + +@app.post("/addRequiredFactor") +async def add_required_factor( + request: Request, session: SessionContainer = Depends(verify_session()) +): + body = await request.json() + if body is None or "factorId" not in body: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + + await add_to_required_secondary_factors_for_user( + session.get_user_id(), body["factorId"] + ) + + return JSONResponse({"status": "OK"}) + + +@app.post("/test/setEnabledRecipes") +async def test_set_enabled_recipes(request: Request): + global enabled_recipes + global enabled_providers + body = await request.json() + if body is None: + raise Exception("Invalid request body") + enabled_recipes = body.get("enabledRecipes") + enabled_providers = body.get("enabledProviders") + custom_init() + return PlainTextResponse("", status_code=200) + + +@app.post("/test/getTOTPCode") +async def test_get_totp_code(request: Request): + from pyotp import TOTP + + body = await request.json() + if body is None or "secret" not in body: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return JSONResponse({"totp": code}) + + @app.get("/test/getDevice") def test_get_device(request: Request): global code_store @@ -789,6 +1362,11 @@ def test_feature_flags(request: Request): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] return JSONResponse({"available": available}) @@ -922,4 +1500,4 @@ def preflight_response(self, request_headers: Headers) -> Response: ) if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=get_api_port()) # type: ignore + uvicorn.run(app, host="0.0.0.0", port=int(get_api_port())) # type: ignore From 7b9d23ca413079f205281c3f54f4dab0c158d9fc Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 3 Oct 2024 16:47:38 +0530 Subject: [PATCH 079/126] more test fixes --- .vscode/launch.json | 15 + tests/auth-react/django3x/manage.py | 1 + tests/auth-react/django3x/mysite/settings.py | 2 +- tests/auth-react/django3x/mysite/store.py | 28 +- tests/auth-react/django3x/mysite/utils.py | 406 +++++++++++++++++-- tests/auth-react/django3x/polls/urls.py | 38 +- tests/auth-react/django3x/polls/views.py | 291 ++++++++++++- 7 files changed, 713 insertions(+), 68 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 58eed6c40..32acced28 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -42,6 +42,21 @@ ], "cwd": "${workspaceFolder}/tests/auth-react/fastapi-server", "jinja": true + }, + { + "name": "Python: Django, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/django3x/manage.py", + "args": [ + "runserver", + "0.0.0.0:8083" + ], + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "cwd": "${workspaceFolder}/tests/auth-react/django3x", + "jinja": true } ] } \ No newline at end of file diff --git a/tests/auth-react/django3x/manage.py b/tests/auth-react/django3x/manage.py index be146f802..ef5aa4985 100644 --- a/tests/auth-react/django3x/manage.py +++ b/tests/auth-react/django3x/manage.py @@ -5,6 +5,7 @@ def main(): + os.environ.setdefault("SUPERTOKENS_ENV", "testing") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings") try: from django.core.management import execute_from_command_line diff --git a/tests/auth-react/django3x/mysite/settings.py b/tests/auth-react/django3x/mysite/settings.py index a8f7d3686..6ad791b27 100644 --- a/tests/auth-react/django3x/mysite/settings.py +++ b/tests/auth-react/django3x/mysite/settings.py @@ -30,7 +30,7 @@ # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -custom_init(None, None) +custom_init() ALLOWED_HOSTS = ["localhost"] diff --git a/tests/auth-react/django3x/mysite/store.py b/tests/auth-react/django3x/mysite/store.py index 37f0dd2ea..07da199ba 100644 --- a/tests/auth-react/django3x/mysite/store.py +++ b/tests/auth-react/django3x/mysite/store.py @@ -1,18 +1,26 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Literal, Optional, Union -_LATEST_URL_WITH_TOKEN = None +latest_url_with_token = "" def save_url_with_token(url_with_token: str): - global _LATEST_URL_WITH_TOKEN - _LATEST_URL_WITH_TOKEN = url_with_token # type: ignore + global latest_url_with_token + latest_url_with_token = url_with_token def get_url_with_token() -> str: - return _LATEST_URL_WITH_TOKEN # type: ignore + return latest_url_with_token -_CODE_STORE: Dict[str, List[Dict[str, Any]]] = {} +code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None def save_code( @@ -20,8 +28,8 @@ def save_code( url_with_link_code: Union[str, None], user_input_code: Union[str, None], ): - global _CODE_STORE - codes = _CODE_STORE.get(pre_auth_session_id, []) + global code_store + codes = code_store.get(pre_auth_session_id, []) # replace sub string in url_with_link_code if url_with_link_code: url_with_link_code = url_with_link_code.replace( @@ -30,8 +38,8 @@ def save_code( codes.append( {"urlWithLinkCode": url_with_link_code, "userInputCode": user_input_code} ) - _CODE_STORE[pre_auth_session_id] = codes + code_store[pre_auth_session_id] = codes def get_codes(pre_auth_session_id: str) -> List[Dict[str, Any]]: - return _CODE_STORE.get(pre_auth_session_id, []) + return code_store.get(pre_auth_session_id, []) diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index 27f988177..c3d03d502 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from dotenv import load_dotenv from typing_extensions import Literal @@ -9,11 +9,14 @@ from supertokens_python.recipe import ( emailpassword, emailverification, + multifactorauth, passwordless, session, thirdparty, + totp, userroles, ) +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.interfaces import ( @@ -51,6 +54,10 @@ from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) from supertokens_python.recipe.session.interfaces import ( APIInterface as SessionAPIInterface, ) @@ -61,12 +68,21 @@ ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import UserRolesRecipe -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import GeneralErrorResponse, User from .store import save_code, save_url_with_token from supertokens_python.recipe import multitenancy +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe import accountlinking +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe load_dotenv() @@ -112,9 +128,7 @@ async def send_email( ) -class CustomPlessSMSService( - passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars] -): +class CustomSMSService(passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars]): async def send_sms( self, template_vars: passwordless.SMSTemplateVars, user_context: Dict[str, Any] ) -> None: @@ -260,12 +274,34 @@ async def get_user_info( # pylint: disable=no-self-use ] -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + import mysite.store + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() @@ -277,6 +313,8 @@ def custom_init( DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -601,20 +639,94 @@ async def resend_code_post( original_implementation.resend_code_post = resend_code_post return original_implementation - if contact_method is not None and flow_type is not None: - if contact_method == "PHONE": + providers_list: List[thirdparty.ProviderInput] = [ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="google", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GOOGLE_CLIENT_ID"], + client_secret=os.environ["GOOGLE_CLIENT_SECRET"], + ), + ], + ), + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="github", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GITHUB_CLIENT_ID"], + client_secret=os.environ["GITHUB_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="facebook", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["FACEBOOK_CLIENT_ID"], + client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="auth0", + name="Auth0", + authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", + authorization_endpoint_query_params={"scope": "openid profile"}, + token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["AUTH0_CLIENT_ID"], + client_secret=os.environ["AUTH0_CLIENT_SECRET"], + ) + ], + ), + override=auth0_provider_override, + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), + ] + + if mysite.store.enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in mysite.store.enabled_providers + ] + + if mysite.store.contact_method is not None and mysite.store.flow_type is not None: + if mysite.store.contact_method == "PHONE": passwordless_init = passwordless.init( contact_config=ContactPhoneOnlyConfig(), - flow_type=flow_type, - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + flow_type=mysite.store.flow_type, + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), ) - elif contact_method == "EMAIL": + elif mysite.store.contact_method == "EMAIL": passwordless_init = passwordless.init( contact_config=ContactEmailOnlyConfig(), - flow_type=flow_type, + flow_type=mysite.store.flow_type, email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), @@ -625,11 +737,11 @@ async def resend_code_post( else: passwordless_init = passwordless.init( contact_config=ContactEmailOrPhoneConfig(), - flow_type=flow_type, + flow_type=mysite.store.flow_type, email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -639,7 +751,7 @@ async def resend_code_post( contact_config=ContactEmailOrPhoneConfig(), flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), ) @@ -648,34 +760,240 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mysite.store.mfa_info: + return mysite.store.mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mysite.store.mfa_info: + if factor_id not in mysite.store.mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mysite.store.mfa_info: + return mysite.store.mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mysite.store.mfa_info: + res.factors.already_setup = mysite.store.mfa_info["alreadySetup"][:] + + if "noContacts" in mysite.store.mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailverification.EmailDeliveryConfig( - CustomEPEmailService() + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mysite.store.mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **mysite.store.accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + if mysite.store.enabled_recipes is not None: + recipe_list = [ + item["init"] + for item in recipe_list + if item["id"] in mysite.store.enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( diff --git a/tests/auth-react/django3x/polls/urls.py b/tests/auth-react/django3x/polls/urls.py index b96151911..8b3cf999e 100644 --- a/tests/auth-react/django3x/polls/urls.py +++ b/tests/auth-react/django3x/polls/urls.py @@ -8,23 +8,49 @@ path("ping", views.ping, name="ping"), path("sessionInfo", views.session_info, name="sessionInfo"), path("token", views.token, name="token"), - path("test/setFlow", views.test_set_flow, name="setFlow"), - path("test/getDevice", views.test_get_device, name="getDevice"), - path("test/featureFlags", views.test_feature_flags, name="featureFlags"), - path("beforeeach", views.before_each, name="beforeeach"), + path("changeEmail", views.change_email, name="changeEmail"), # type: ignore + path("setupTenant", views.setup_tenant, name="setupTenant"), # type: ignore + path("removeTenant", views.remove_tenant, name="removeTenant"), # type: ignore + path( + "removeUserFromTenant", + views.remove_user_from_tenant, # type: ignore + name="removeUserFromTenant", + ), # type: ignore + path("addUserToTenant", views.add_user_to_tenant, name="addUserToTenant"), # type: ignore + path("test/setFlow", views.test_set_flow, name="setFlow"), # type: ignore + path( + "test/setAccountLinkingConfig", + views.test_set_account_linking_config, # type: ignore + name="setAccountLinkingConfig", + ), # type: ignore + path("setMFAInfo", views.set_mfa_info, name="setMfaInfo"), # type: ignore + path( + "addRequiredFactor", + views.add_required_factor, # type: ignore + name="addRequiredFactor", + ), # type: ignore + path( + "test/setEnabledRecipes", + views.test_set_enabled_recipes, # type: ignore + name="setEnabledRecipes", + ), + path("test/getTOTPCode", views.test_get_totp_code, name="getTotpCode"), # type: ignore + path("test/getDevice", views.test_get_device, name="getDevice"), # type: ignore + path("test/featureFlags", views.test_feature_flags, name="featureFlags"), # type: ignore + path("beforeeach", views.before_each, name="beforeeach"), # type: ignore ] mode = os.environ.get("APP_MODE", "asgi") if mode == "asgi": - urlpatterns += [ + urlpatterns += [ # type: ignore path("unverifyEmail", views.unverify_email_api, name="unverifyEmail"), # type: ignore path("setRole", views.set_role_api, name="setRole"), # type: ignore path("checkRole", views.check_role_api, name="checkRole"), # type: ignore path("deleteUser", views.delete_user, name="deleteUser"), # type: ignore ] else: - urlpatterns += [ + urlpatterns += [ # type: ignore path("unverifyEmail", views.sync_unverify_email_api, name="unverifyEmail"), path("setRole", views.sync_set_role_api, name="setRole"), path("checkRole", views.sync_check_role_api, name="checkRole"), diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index 550c43ce2..c1da33fcc 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -15,16 +15,57 @@ import os from typing import List, Dict, Any -from django.conf import settings from django.http import HttpRequest, HttpResponse, JsonResponse from mysite.store import get_codes, get_url_with_token from mysite.utils import custom_init +from supertokens_python import convert_to_recipe_user_id +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password from supertokens_python.recipe.emailverification import EmailVerificationClaim +from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) from supertokens_python.recipe.userroles import UserRoleClaim, PermissionClaim -from supertokens_python.types import AccountInfo +from supertokens_python.types import AccountInfo, RecipeUserId +from supertokens_python.recipe.multitenancy.asyncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) +from supertokens_python.recipe.passwordless.asyncio import update_user +from supertokens_python.recipe.passwordless.interfaces import ( + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, +) +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, +) mode = os.environ.get("APP_MODE", "asgi") @@ -179,16 +220,247 @@ def test_get_device(request: HttpRequest): return JsonResponse({"preAuthSessionId": pre_auth_session_id, "codes": codes}) -def test_set_flow(request: HttpRequest): +async def change_email(request: HttpRequest): body = json.loads(request.body) - contact_method = body["contactMethod"] - flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + if body is None: + raise Exception("Should never come here") + + if body["rid"] == "emailpassword": + resp = await update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return JsonResponse({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JsonResponse(resp.to_json()) + elif body["rid"] == "thirdparty": + user = await get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = await manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return JsonResponse( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return JsonResponse( + {"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason} + ) + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = await update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return JsonResponse({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return JsonResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JsonResponse( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": resp.reason, + } + ) + + raise Exception("Should not come here") + + +async def setup_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body.get("coreConfig", {}) + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = await create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + await create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return JsonResponse({"status": "OK", "createdNew": core_resp.created_new}) + + +async def add_user_to_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return JsonResponse( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return JsonResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return JsonResponse({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return JsonResponse( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +async def remove_user_from_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await disassociate_user_from_tenant( + tenant_id, RecipeUserId(recipe_user_id) + ) + + return JsonResponse({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +async def remove_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = await delete_tenant(tenant_id) + + return JsonResponse({"status": "OK", "didExist": core_resp.did_exist}) + + +async def test_set_flow(request: HttpRequest): + body = json.loads(request.body) + import mysite.store + + mysite.store.contact_method = body["contactMethod"] + mysite.store.flow_type = body["flowType"] + custom_init() + return HttpResponse("") + + +async def test_set_account_linking_config(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + mysite.store.accountlinking_config = body + custom_init() + return HttpResponse("") + + +async def set_mfa_info(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + mysite.store.mfa_info = body + return JsonResponse({"status": "OK"}) + + +@verify_session() +async def add_required_factor(request: HttpRequest): + session_: SessionContainer = request.supertokens # type: ignore + body = json.loads(request.body) + if body is None or "factorId" not in body: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + + await add_to_required_secondary_factors_for_user( + session_.get_user_id(), body["factorId"] + ) + + return JsonResponse({"status": "OK"}) + + +def test_set_enabled_recipes(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + mysite.store.enabled_recipes = body.get("enabledRecipes") + mysite.store.enabled_providers = body.get("enabledProviders") + custom_init() return HttpResponse("") +def test_get_totp_code(request: HttpRequest): + from pyotp import TOTP + + body = json.loads(request.body) + if body is None or "secret" not in body: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return JsonResponse({"totp": code}) + + def before_each(request: HttpRequest): - setattr(settings, "CODE_STORE", dict()) + import mysite.store + + mysite.store.code_store = dict() custom_init() return HttpResponse("") @@ -202,6 +474,11 @@ def test_feature_flags(request: HttpRequest): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] } ) From bb56f1e755e3e1ce37d75d1636e30b7df3b9fe5e Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 3 Oct 2024 17:03:02 +0530 Subject: [PATCH 080/126] more fixes --- tests/frontendIntegration/django2x/polls/views.py | 2 +- tests/frontendIntegration/django3x/polls/views.py | 2 +- tests/frontendIntegration/drf_async/polls/views.py | 2 +- tests/frontendIntegration/drf_sync/polls/views.py | 2 +- tests/frontendIntegration/fastapi-server/app.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index e2472e21c..eda19409a 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -405,7 +405,7 @@ def login(request: HttpRequest): if request.method == "POST": user_id = json.loads(request.body)["userId"] - session_ = create_new_session(request, "public", user_id) + session_ = create_new_session(request, "public", RecipeUserId(user_id)) return HttpResponse(session_.get_user_id()) else: return send_options_api_response() diff --git a/tests/frontendIntegration/django3x/polls/views.py b/tests/frontendIntegration/django3x/polls/views.py index 0ec1fa4d2..1e182fb76 100644 --- a/tests/frontendIntegration/django3x/polls/views.py +++ b/tests/frontendIntegration/django3x/polls/views.py @@ -408,7 +408,7 @@ async def login(request: HttpRequest): if request.method == "POST": user_id = json.loads(request.body)["userId"] - session_ = await create_new_session(request, "public", user_id) + session_ = await create_new_session(request, "public", RecipeUserId(user_id)) return HttpResponse(session_.get_user_id()) else: return send_options_api_response() diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index 757ae922f..e13848aef 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -434,7 +434,7 @@ async def login(request: Request): # type: ignore if request.method == "POST": # type: ignore user_id = request.data["userId"] # type: ignore - session_ = await create_new_session(request, "public", user_id) # type: ignore + session_ = await create_new_session(request, "public", RecipeUserId(user_id)) # type: ignore return Response(session_.get_user_id()) # type: ignore else: return send_options_api_response() # type: ignore diff --git a/tests/frontendIntegration/drf_sync/polls/views.py b/tests/frontendIntegration/drf_sync/polls/views.py index bbcbbb5d3..cd776104b 100644 --- a/tests/frontendIntegration/drf_sync/polls/views.py +++ b/tests/frontendIntegration/drf_sync/polls/views.py @@ -434,7 +434,7 @@ def login(request: Request): # type: ignore if request.method == "POST": # type: ignore user_id = request.data["userId"] # type: ignore - session_ = create_new_session(request, "public", user_id) # type: ignore + session_ = create_new_session(request, "public", RecipeUserId(user_id)) # type: ignore return Response(session_.get_user_id()) # type: ignore else: return send_options_api_response() # type: ignore diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index be1d31f96..addaf27ab 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -273,7 +273,7 @@ def login_options(): @app.post("/login") async def login(request: Request): user_id = (await request.json())["userId"] - _session = await create_new_session(request, "public", user_id) + _session = await create_new_session(request, "public", RecipeUserId(user_id)) return PlainTextResponse(content=_session.get_user_id()) From fc8927b873848ef528749cd471a3b01927c02ede Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 3 Oct 2024 17:08:04 +0530 Subject: [PATCH 081/126] fixes --- tests/auth-react/django3x/polls/views.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index c1da33fcc..c531c9ddc 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -460,7 +460,14 @@ def test_get_totp_code(request: HttpRequest): def before_each(request: HttpRequest): import mysite.store + mysite.store.contact_method = "EMAIL_OR_PHONE" + mysite.store.flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + mysite.store.latest_url_with_token = "" mysite.store.code_store = dict() + mysite.store.accountlinking_config = {} + mysite.store.enabled_providers = None + mysite.store.enabled_recipes = None + mysite.store.mfa_info = {} custom_init() return HttpResponse("") From c7b77b6c20e578e6ff3a49f9a7cf9b849750e886 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 4 Oct 2024 00:03:06 +0530 Subject: [PATCH 082/126] fixes --- .vscode/launch.json | 15 +++++++++++++++ dev-requirements.txt | 3 ++- setup.py | 1 + 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 32acced28..f98c12000 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -16,6 +16,21 @@ }, "jinja": true }, + { + "name": "Python: FastAPI, supertokens-website tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/frontendIntegration/fastapi-server/app.py", + "args": [ + "--port", + "8080" + ], + "cwd": "${workspaceFolder}/tests/frontendIntegration/fastapi-server", + "env": { + "FLASK_DEBUG": "1" + }, + "jinja": true + }, { "name": "Python: Flask, supertokens-auth-react tests", "type": "python", diff --git a/dev-requirements.txt b/dev-requirements.txt index 0694d481a..53cb8e21f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -85,4 +85,5 @@ uvicorn==0.18.2 Werkzeug==2.0.3 wrapt==1.13.3 zipp==3.7.0 -pyotp==2.9.0 \ No newline at end of file +pyotp==2.9.0 +aiofiles==24.1.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 1b3b89397..ff4cf74c1 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "uvicorn==0.18.2", "python-dotenv==0.19.2", "pyotp==2.9.0", + "aiofiles==24.1.0", ] ), "flask": ( From 75ec7371462738a4643126b56526f268a92e13f4 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 7 Oct 2024 11:46:39 +0530 Subject: [PATCH 083/126] solves https://github.com/supertokens/supertokens-node/issues/657 --- supertokens_python/recipe/dashboard/api/users_get.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supertokens_python/recipe/dashboard/api/users_get.py b/supertokens_python/recipe/dashboard/api/users_get.py index 817f51e8c..5e9c0ae7b 100644 --- a/supertokens_python/recipe/dashboard/api/users_get.py +++ b/supertokens_python/recipe/dashboard/api/users_get.py @@ -79,7 +79,7 @@ async def handle_users_get_api( async def get_user_metadata_and_update_user(user_idx: int) -> None: user = users_response.users[user_idx] - user_metadata = await get_user_metadata(user.id, user_context) + user_metadata = await get_user_metadata(user.id) first_name = user_metadata.metadata.get("first_name") last_name = user_metadata.metadata.get("last_name") From 0708bc48dba9ae027143318dfcdf24f41becb127 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 7 Oct 2024 17:32:15 +0530 Subject: [PATCH 084/126] fixes a small issue --- supertokens_python/recipe/emailverification/recipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 1c6e7c1a7..0fc977b5b 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -414,7 +414,7 @@ async def update_session_if_required_post_email_verification( await revoke_all_sessions_for_user( recipe_user_id_whose_email_got_verified.get_as_string(), False, - session.get_tenant_id(), + None, user_context, ) From 0128820aff295c7bbc49390c783c302e5db9e8d7 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Tue, 8 Oct 2024 12:16:43 +0530 Subject: [PATCH 085/126] fixes script for cicd --- .circleci/setupAndTestBackendSDKWithFreeCore.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/.circleci/setupAndTestBackendSDKWithFreeCore.sh b/.circleci/setupAndTestBackendSDKWithFreeCore.sh index c301a0c18..d774003dd 100755 --- a/.circleci/setupAndTestBackendSDKWithFreeCore.sh +++ b/.circleci/setupAndTestBackendSDKWithFreeCore.sh @@ -56,7 +56,6 @@ ST_CONNECTION_URI=http://localhost:8081 # start test-server pushd tests/test-server -sh setup-for-test.sh SUPERTOKENS_ENV=testing API_PORT=$API_PORT ST_CONNECTION_URI=$ST_CONNECTION_URI python3 app.py & popd From 90d8d52926d75e70aa7cfd457ddc44edd83cf9ac Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 11:48:01 +0530 Subject: [PATCH 086/126] small changes --- .vscode/launch.json | 15 ++++++++++++++- tests/test-server/app.py | 6 ++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index f98c12000..19251e616 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -72,6 +72,19 @@ }, "cwd": "${workspaceFolder}/tests/auth-react/django3x", "jinja": true - } + }, + { + "name": "Python: backend-sdk-testing repo", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/test-server/app.py", + "cwd": "${workspaceFolder}/tests/test-server", + "env": { + "SUPERTOKENS_ENV": "testing", + "API_PORT": "3030", + "ST_CONNECTION_URI": "http://localhost:8081" + }, + "console": "integratedTerminal" + }, ] } \ No newline at end of file diff --git a/tests/test-server/app.py b/tests/test-server/app.py index f743a56fe..49c16fee7 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,5 +1,8 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.totp.recipe import TOTPRecipe from utils import init_test_claims from supertokens_python.process_state import ProcessState from supertokens_python.recipe.dashboard.recipe import DashboardRecipe @@ -136,6 +139,9 @@ def st_reset(): DashboardRecipe.reset() PasswordlessRecipe.reset() MultitenancyRecipe.reset() + AccountLinkingRecipe.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def init_st(config): # type: ignore From 1832599883db8b7e40c371e836b1d64ac68d413c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 12:34:16 +0530 Subject: [PATCH 087/126] fixes test server code --- tests/test-server/app.py | 13 +++- tests/test-server/override_logging.py | 3 + tests/test-server/session.py | 101 ++++++++++++++++---------- 3 files changed, 77 insertions(+), 40 deletions(-) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 49c16fee7..7e1eedc79 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify +from supertokens_python.framework import BaseRequest from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe @@ -49,11 +50,21 @@ def default_st_init(): + def origin_func( + request: Optional[BaseRequest] = None, context: Dict[str, Any] = {} + ) -> str: + if request is None: + return "http://localhost:8080" + origin = request.get_header("origin") + if origin is not None: + return origin + return "http://localhost:8080" + init( app_info=InputAppInfo( app_name="SuperTokens", api_domain="http://api.supertokens.io", - website_domain="http://localhost:3000", + origin=origin_func, ), supertokens_config=SupertokensConfig(connection_uri="http://localhost:3567"), framework="flask", diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index a7dab1b76..f57c30570 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -4,6 +4,7 @@ from httpx import Response from supertokens_python.framework.flask.flask_request import FlaskRequest +from supertokens_python.types import RecipeUserId override_logs: List[Dict[str, Any]] = [] @@ -40,5 +41,7 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return "FlaskRequest" if isinstance(data, Response): return "Response" + if isinstance(data, RecipeUserId): + return data.get_as_string() return data diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 1d0574122..2d60c50ac 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,6 +1,11 @@ +from typing import Any from flask import Flask, request, jsonify +from supertokens_python.recipe.session.interfaces import TokenInfo +from supertokens_python.recipe.session.jwt import ( + parse_jwt_without_signature_verification, +) +from supertokens_python.types import RecipeUserId from utils import deserialize_validator -from supertokens_python import async_to_sync_wrapper from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session @@ -22,7 +27,7 @@ def create_new_session_without_request_response(): # type: ignore session_container = session.create_new_session_without_request_response( tenant_id, - user_id, + RecipeUserId(user_id), access_token_payload, session_data_in_database, disable_anti_csrf, @@ -76,29 +81,13 @@ def assert_claims(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - session_container = Session( - recipe_implementation=SessionRecipe.get_instance().recipe_implementation, - config=SessionRecipe.get_instance().config, - access_token=data["session"]["accessToken"], - front_token=data["session"]["frontToken"], - refresh_token=None, # We don't have refresh token in the input - anti_csrf_token=None, # We don't have anti-csrf token in the input - session_handle=data["session"]["sessionHandle"], - user_id=data["session"]["userId"], - recipe_user_id=data["session"]["recipeUserId"], - user_data_in_access_token=data["session"]["userDataInAccessToken"], - req_res_info=None, # We don't have this information in the input - access_token_updated=data["session"]["accessTokenUpdated"], - tenant_id=data["session"]["tenantId"], - ) + session_container = convert_session_to_container(data) claim_validators = list(map(deserialize_validator, data["claimValidators"])) user_context = data.get("userContext", {}) try: - async_to_sync_wrapper.sync( - session_container.assert_claims(claim_validators, user_context) - ) + session_container.sync_assert_claims(claim_validators, user_context) return jsonify( { "status": "OK", @@ -135,28 +124,13 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - session_container = Session( - recipe_implementation=SessionRecipe.get_instance().recipe_implementation, - config=SessionRecipe.get_instance().config, - access_token=data["session"]["accessToken"], - front_token=data["session"]["frontToken"], - refresh_token=None, # We don't have refresh token in the input - anti_csrf_token=None, # We don't have anti-csrf token in the input - session_handle=data["session"]["sessionHandle"], - user_id=data["session"]["userId"], - recipe_user_id=data["session"]["recipeUserId"], - user_data_in_access_token=data["session"]["userDataInAccessToken"], - req_res_info=None, # We don't have this information in the input - access_token_updated=data["session"]["accessTokenUpdated"], - tenant_id=data["session"]["tenantId"], - ) + session_container = convert_session_to_container(data) + access_token_payload_update = data["accessTokenPayloadUpdate"] user_context = data.get("userContext", {}) - async_to_sync_wrapper.sync( - session_container.merge_into_access_token_payload( - access_token_payload_update, user_context - ) + session_container.sync_merge_into_access_token_payload( + access_token_payload_update, user_context ) return jsonify( @@ -165,6 +139,7 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore "updatedSession": { "sessionHandle": session_container.get_handle(), "userId": session_container.get_user_id(), + "recipeUserId": session_container.get_recipe_user_id().get_as_string(), "tenantId": session_container.get_tenant_id(), "userDataInAccessToken": session_container.get_access_token_payload(), "accessToken": session_container.get_access_token(), @@ -183,3 +158,51 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore }, } ) + + +def convert_session_to_container(data: Any) -> Session: + jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) + jwt_payload = jwt_info.payload + + user_id = jwt_info.version == 2 and jwt_payload["userId"] or jwt_payload["sub"] + session_handle = jwt_payload["sessionHandle"] + + recipe_user_id = RecipeUserId(jwt_payload.get("rsub", user_id)) + anti_csrf_token = jwt_payload.get("antiCsrfToken") + tenant_id = jwt_info.version >= 4 and jwt_payload["tId"] or "public" + + return Session( + recipe_implementation=SessionRecipe.get_instance().recipe_implementation, + config=SessionRecipe.get_instance().config, + access_token=data["session"]["accessToken"], + front_token=data["session"]["frontToken"], + refresh_token=( + TokenInfo( + ( + data["session"]["refreshToken"] + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["token"] + ), + ( + -1 + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["expiry"] + ), + ( + -1 + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["createdTime"] + ), + ) + if "refreshToken" in data["session"] + else None + ), + anti_csrf_token=anti_csrf_token, + session_handle=session_handle, + user_id=user_id, + recipe_user_id=recipe_user_id, + user_data_in_access_token=jwt_payload, + req_res_info=None, # We don't have this information in the input + access_token_updated=False, + tenant_id=tenant_id, + ) From 999c63661746c94e3b2f39cc1cab8c20d4bacc1d Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 13:31:11 +0530 Subject: [PATCH 088/126] fixes test server --- tests/test-server/session.py | 17 ++++++++++++++--- tests/test-server/utils.py | 17 +++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 2d60c50ac..f97492218 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -5,7 +5,7 @@ parse_jwt_without_signature_verification, ) from supertokens_python.types import RecipeUserId -from utils import deserialize_validator +from utils import deserialize_validator, get_max_version from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session @@ -19,7 +19,18 @@ def create_new_session_without_request_response(): # type: ignore return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data.get("tenantId", "public") - user_id = data["userId"] + from supertokens_python import convert_to_recipe_user_id + + fdi_version = request.headers.get("fdi-version") + assert fdi_version is not None + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + # fdi_version <= "1.17" or (fdi_version >= "2.0" and fdi_version < "3.0") + recipe_user_id = convert_to_recipe_user_id(data["userId"]) + else: + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) access_token_payload = data.get("accessTokenPayload", {}) session_data_in_database = data.get("sessionDataInDatabase", {}) disable_anti_csrf = data.get("disableAntiCsrf") @@ -27,7 +38,7 @@ def create_new_session_without_request_response(): # type: ignore session_container = session.create_new_session_without_request_response( tenant_id, - RecipeUserId(user_id), + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index 3ccc84a4b..073d006ff 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -39,3 +39,20 @@ def toSnakeCase(camel_case: str) -> str: else: result += char return result + + +def get_max_version(v1: str, v2: str) -> str: + v1_split = v1.split(".") + v2_split = v2.split(".") + max_loop = min(len(v1_split), len(v2_split)) + + for i in range(max_loop): + if int(v1_split[i]) > int(v2_split[i]): + return v1 + if int(v2_split[i]) > int(v1_split[i]): + return v2 + + if len(v1_split) > len(v2_split): + return v1 + + return v2 From e5fe1dac74cfefc78e6edbfcba37e149fd4ce062 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 17:01:01 +0530 Subject: [PATCH 089/126] fixes more code --- .../session/session_request_functions.py | 4 +- tests/test-server/app.py | 19 ++- tests/test-server/test_functions_mapper.py | 148 +++++++++++++++++- 3 files changed, 165 insertions(+), 6 deletions(-) diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index b1e297c2d..7d70cb2ee 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -269,7 +269,9 @@ async def create_new_session_in_request( del final_access_token_payload[prop] for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, recipe_user_id, tenant_id, user_context) + update = await claim.build( + user_id, recipe_user_id, tenant_id, final_access_token_payload, user_context + ) final_access_token_payload.update(update) log_debug_message("createNewSession: Access token payload built") diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 7e1eedc79..549f88058 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -16,7 +16,7 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -from test_functions_mapper import get_func # type: ignore +from test_functions_mapper import get_func, get_override_params, reset_override_params # type: ignore from emailpassword import add_emailpassword_routes from multitenancy import add_multitenancy_routes from session import add_session_routes @@ -137,6 +137,8 @@ def inner(*args, **kwargs): # type: ignore def st_reset(): + override_logging.reset_override_logs() + reset_override_params() ProcessState.get_instance().reset() Supertokens.reset() SessionRecipe.reset() @@ -331,7 +333,15 @@ def inner( user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: if interceptor_func is not None: - return interceptor_func(url, method, headers, params, body, user_context) # type: ignore + resp = interceptor_func(url, method, headers, params, body, user_context) # type: ignore + return { + "url": resp[0], + "method": resp[1], + "headers": resp[2], + "params": resp[3], + "body": resp[4], + "user_context": resp[5], + } return { "url": url, "method": method, @@ -382,7 +392,7 @@ def init_handler(): @app.route("/test/overrideparams", methods=["GET"]) # type: ignore def override_params(): - return jsonify("TODO") + return jsonify(get_override_params().to_json()) @app.route("/test/featureflag", methods=["GET"]) # type: ignore @@ -391,8 +401,9 @@ def feature_flag(): @app.route("/test/resetoverrideparams", methods=["POST"]) # type: ignore -def reset_override_params(): +def reset_override_params_api(): override_logging.reset_override_logs() + reset_override_params() return jsonify({"ok": True}) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 97a417d58..2efe4e177 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -1,4 +1,8 @@ -from typing import Callable +from typing import Callable, List +from typing import Dict, Any, Optional +from supertokens_python.recipe.accountlinking import RecipeLevelUser +from supertokens_python.types import RecipeUserId +from supertokens_python.types import APIResponse, User class Info: @@ -15,3 +19,145 @@ def func(*args): # type: ignore return func # type: ignore raise Exception("Unknown eval string") + + +class OverrideParams(APIResponse): + def __init__( + self, + send_email_to_user_id: Optional[str] = None, + token: Optional[str] = None, + user_post_password_reset: Optional[User] = None, + email_post_password_reset: Optional[str] = None, + send_email_callback_called: Optional[bool] = None, + send_email_to_user_email: Optional[str] = None, + send_email_inputs: Optional[List[str]] = None, + send_sms_inputs: Optional[List[str]] = None, + send_email_to_recipe_user_id: Optional[str] = None, + user_in_callback: Optional[User] = None, + email: Optional[str] = None, + new_account_info_in_callback: Optional[RecipeLevelUser] = None, + primary_user_in_callback: Optional[User] = None, + user_id_in_callback: Optional[str] = None, + recipe_user_id_in_callback: Optional[str] = None, + core_call_count: int = 0, + store: Optional[Any] = None, + ): + self.send_email_to_user_id = send_email_to_user_id + self.token = token + self.user_post_password_reset = user_post_password_reset + self.email_post_password_reset = email_post_password_reset + self.send_email_callback_called = send_email_callback_called + self.send_email_to_user_email = send_email_to_user_email + self.send_email_inputs = send_email_inputs + self.send_sms_inputs = send_sms_inputs + self.send_email_to_recipe_user_id = send_email_to_recipe_user_id + self.user_in_callback = user_in_callback + self.email = email + self.new_account_info_in_callback = new_account_info_in_callback + self.primary_user_in_callback = primary_user_in_callback + self.user_id_in_callback = user_id_in_callback + self.recipe_user_id_in_callback = recipe_user_id_in_callback + self.core_call_count = core_call_count + self.store = store + + def to_json(self) -> Dict[str, Any]: + return { + "sendEmailToUserId": self.send_email_to_user_id, + "token": self.token, + "userPostPasswordReset": ( + self.user_post_password_reset.to_json() + if self.user_post_password_reset is not None + else None + ), + "emailPostPasswordReset": self.email_post_password_reset, + "sendEmailCallbackCalled": self.send_email_callback_called, + "sendEmailToUserEmail": self.send_email_to_user_email, + "sendEmailInputs": self.send_email_inputs, + "sendSmsInputs": self.send_sms_inputs, + "sendEmailToRecipeUserId": self.send_email_to_recipe_user_id, + "userInCallback": ( + self.user_in_callback.to_json() + if self.user_in_callback is not None + else None + ), + "email": self.email, + "newAccountInfoInCallback": self.new_account_info_in_callback, + "primaryUserInCallback": ( + self.primary_user_in_callback.to_json() + if self.primary_user_in_callback is not None + else None + ), + "userIdInCallback": self.user_id_in_callback, + "recipeUserIdInCallback": self.recipe_user_id_in_callback, + "info": { + "coreCallCount": self.core_call_count, + }, + "store": self.store, + } + + +def get_override_params() -> OverrideParams: + return OverrideParams( + send_email_to_user_id=send_email_to_user_id, + token=token, + user_post_password_reset=user_post_password_reset, + email_post_password_reset=email_post_password_reset, + send_email_callback_called=send_email_callback_called, + send_email_to_user_email=send_email_to_user_email, + send_email_inputs=send_email_inputs, + send_sms_inputs=send_sms_inputs, + send_email_to_recipe_user_id=send_email_to_recipe_user_id, + user_in_callback=user_in_callback, + email=email, + new_account_info_in_callback=new_account_info_in_callback, + primary_user_in_callback=( + primary_user_in_callback if primary_user_in_callback else None + ), + user_id_in_callback=user_id_in_callback, + recipe_user_id_in_callback=( + recipe_user_id_in_callback.get_as_string() + if isinstance(recipe_user_id_in_callback, RecipeUserId) + else None + ), + core_call_count=Info.core_call_count, + store=store, + ) + + +def reset_override_params(): + global send_email_to_user_id, token, user_post_password_reset, email_post_password_reset, send_email_callback_called, send_email_to_user_email, send_email_inputs, send_sms_inputs, send_email_to_recipe_user_id, user_in_callback, email, primary_user_in_callback, new_account_info_in_callback, user_id_in_callback, recipe_user_id_in_callback, store + send_email_to_user_id = None + token = None + user_post_password_reset = None + email_post_password_reset = None + send_email_callback_called = False + send_email_to_user_email = None + send_email_inputs = [] + send_sms_inputs = [] + send_email_to_recipe_user_id = None + user_in_callback = None + email = None + primary_user_in_callback = None + new_account_info_in_callback = None + user_id_in_callback = None + recipe_user_id_in_callback = None + store = None + Info.core_call_count = 0 + + +send_email_to_user_id: Optional[str] = None +token: Optional[str] = None +user_post_password_reset: Optional[User] = None +email_post_password_reset: Optional[str] = None +send_email_callback_called: bool = False +send_email_to_user_email: Optional[str] = None +send_email_inputs: List[str] = [] +send_sms_inputs: List[str] = [] +send_email_to_recipe_user_id: Optional[str] = None +user_in_callback: Optional[User] = None +email: Optional[str] = None +primary_user_in_callback: Optional[User] = None +new_account_info_in_callback: Optional[RecipeLevelUser] = None +user_id_in_callback: Optional[str] = None +recipe_user_id_in_callback: Optional[RecipeUserId] = None +store: Optional[str] = None From 9e181ddc5a34a288bbcb46161b26d22df3255893 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 18:22:55 +0530 Subject: [PATCH 090/126] gets more tests to pass --- tests/test-server/app.py | 175 +++++++++++++++------ tests/test-server/test_functions_mapper.py | 114 +++++++++++++- 2 files changed, 241 insertions(+), 48 deletions(-) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 549f88058..c6690d726 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,6 +1,8 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify from supertokens_python.framework import BaseRequest +from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig +from supertokens_python.recipe import accountlinking from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe @@ -16,7 +18,7 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -from test_functions_mapper import get_func, get_override_params, reset_override_params # type: ignore +from test_functions_mapper import get_func, get_override_params, reset_override_params from emailpassword import add_emailpassword_routes from multitenancy import add_multitenancy_routes from session import add_session_routes @@ -80,48 +82,52 @@ def toCamelCase(snake_case: str) -> str: return components[0] + "".join(x.title() for x in components[1:]) -def create_override(input, member, name): # type: ignore - member_val = getattr(input, member) # type: ignore +def create_override( + oI: Any, functionName: str, name: str, override_name: Optional[str] = None +): + implementation = oI if override_name is None else get_func(override_name)(oI) + originalFunction = getattr(implementation, functionName) - async def override(*args, **kwargs): # type: ignore + async def finalFunction(*args: Any, **kwargs: Any): override_logging.log_override_event( - name + "." + toCamelCase(member), # type: ignore + name + "." + toCamelCase(functionName), "CALL", {"args": args, "kwargs": kwargs}, ) try: - res = await member_val(*args, **kwargs) + res = await originalFunction(*args, **kwargs) override_logging.log_override_event( - name + "." + toCamelCase(member), "RES", res # type: ignore + name + "." + toCamelCase(functionName), "RES", res ) return res except Exception as e: override_logging.log_override_event( - name + "." + toCamelCase(member), "REJ", e # type: ignore + name + "." + toCamelCase(functionName), "REJ", e ) raise e - setattr(input, member, override) # type: ignore + setattr(oI, functionName, finalFunction) def override_builder_with_logging( name: str, override_name: Optional[str] = None ) -> Callable[[T], T]: - def builder(input: T) -> T: - for member in dir(input): - if member.startswith("__"): - continue + def builder(oI: T) -> T: + members = [ + attr + for attr in dir(oI) + if callable(getattr(oI, attr)) and not attr.startswith("__") + ] - member_val = getattr(input, member) - if callable(member_val): - create_override(input, member, name) - return input + for member in members: + create_override(oI, member, name, override_name) + return oI return builder def logging_override_func_sync(name: str, c: Any) -> Any: - def inner(*args, **kwargs): # type: ignore + def inner(*args: Any, **kwargs: Any) -> Any: override_logging.log_override_event( name, "CALL", {"args": args, "kwargs": kwargs} ) @@ -133,7 +139,25 @@ def inner(*args, **kwargs): # type: ignore override_logging.log_override_event(name, "REJ", e) raise e - return inner # type: ignore + return inner + + +def callback_with_log( + name: str, override_name: Optional[str], default_value: Any = None +) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + if override_name: + impl = get_func(override_name) + else: + + async def default_func(*args: Any, **kwargs: Any) -> Any: + return default_value + + impl = default_func + + return logging_override_func_sync(name, impl)(*args, **kwargs) + + return wrapper def st_reset(): @@ -157,16 +181,16 @@ def st_reset(): MultiFactorAuthRecipe.reset() -def init_st(config): # type: ignore +def init_st(config: Dict[str, Any]): st_reset() override_logging.reset_override_logs() recipe_list: List[Callable[[AppInfo], RecipeModule]] = [] - for recipe_config in config.get("recipeList", []): # type: ignore - recipe_id = recipe_config.get("recipeId") # type: ignore + for recipe_config in config.get("recipeList", []): + recipe_id = recipe_config.get("recipeId") if recipe_id == "emailpassword": sign_up_feature_input = None - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) if "signUpFeature" in recipe_config_json: sign_up_feature = recipe_config_json["signUpFeature"] if "formFields" in sign_up_feature: @@ -178,11 +202,29 @@ def init_st(config): # type: ignore ) recipe_list.append( - emailpassword.init(sign_up_feature=sign_up_feature_input) + emailpassword.init( + sign_up_feature=sign_up_feature_input, + email_delivery=EmailDeliveryConfig( + override=override_builder_with_logging( + "EmailPassword.emailDelivery.override", + config.get("emailDelivery", {}).get("override", None), + ) + ), + override=emailpassword.InputOverrideConfig( + apis=override_builder_with_logging( + "EmailPassword.override.apis", + config.get("override", {}).get("apis", None), + ), + functions=override_builder_with_logging( + "EmailPassword.override.functions", + config.get("override", {}).get("functions", None), + ), + ), + ) ) elif recipe_id == "session": - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( session.init( cookie_secure=recipe_config_json.get("cookieSecure"), @@ -202,11 +244,45 @@ def init_st(config): # type: ignore use_dynamic_access_token_signing_key=recipe_config_json.get( "useDynamicAccessTokenSigningKey" ), + override=session.InputOverrideConfig( + apis=override_builder_with_logging( + "Session.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + functions=override_builder_with_logging( + "Session.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), + ) + ) + elif recipe_id == "accountlinking": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + accountlinking.init( + should_do_automatic_account_linking=callback_with_log( + "AccountLinking.shouldDoAutomaticAccountLinking", + recipe_config_json.get("shouldDoAutomaticAccountLinking"), + accountlinking.ShouldNotAutomaticallyLink(), + ), + on_account_linked=callback_with_log( + "AccountLinking.onAccountLinked", + recipe_config_json.get("on_account_linked"), + ), + override=accountlinking.InputOverrideConfig( + functions=override_builder_with_logging( + "AccountLinking.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), ) ) - elif recipe_id == "thirdparty": - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) providers: List[thirdparty.ProviderInput] = [] if "signInAndUpFeature" in recipe_config_json: sign_in_up_feature = recipe_config_json["signInAndUpFeature"] @@ -298,21 +374,26 @@ def init_st(config): # type: ignore ) elif recipe_id == "emailverification": - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) ev_config: Dict[str, Any] = {"mode": "OPTIONAL"} if "mode" in recipe_config_json: ev_config["mode"] = recipe_config_json["mode"] - override_functions = override_builder_with_logging("EmailVerification.override.functions") # type: ignore + override_functions = override_builder_with_logging( + "EmailVerification.override.functions", + config.get("override", {}).get("functions", None), + ) ev_config["override"] = emailverification.InputOverrideConfig( - functions=override_functions # type: ignore + functions=override_functions ) recipe_list.append(emailverification.init(**ev_config)) - interceptor_func = None # type: ignore - if config.get("supertokens", {}).get("networkInterceptor") is not None: # type: ignore - interceptor_func = get_func(config.get("supertokens", {}).get("networkInterceptor")) # type: ignore + interceptor_func = None + if config.get("supertokens", {}).get("networkInterceptor") is not None: + interceptor_func = get_func( + config.get("supertokens", {}).get("networkInterceptor") + ) def network_interceptor_func( url: str, @@ -333,7 +414,9 @@ def inner( user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: if interceptor_func is not None: - resp = interceptor_func(url, method, headers, params, body, user_context) # type: ignore + resp = interceptor_func( + url, method, headers, params, body, user_context + ) return { "url": resp[0], "method": resp[1], @@ -351,7 +434,9 @@ def inner( "user_context": user_context, } - res = logging_override_func_sync("networkInterceptor", inner)(url, method, headers, params, body, user_context) # type: ignore + res = logging_override_func_sync("networkInterceptor", inner)( + url, method, headers, params, body, user_context + ) return ( res.get("url"), res.get("method"), @@ -362,12 +447,12 @@ def inner( init( app_info=InputAppInfo( - app_name=config["appInfo"]["appName"], # type: ignore - api_domain=config["appInfo"]["apiDomain"], # type: ignore - website_domain=config["appInfo"]["websiteDomain"], # type: ignore + app_name=config["appInfo"]["appName"], + api_domain=config["appInfo"]["apiDomain"], + website_domain=config["appInfo"]["websiteDomain"], ), supertokens_config=SupertokensConfig( - connection_uri=config["supertokens"]["connectionURI"], # type: ignore + connection_uri=config["supertokens"]["connectionURI"], network_interceptor=network_interceptor_func, ), framework="flask", @@ -383,7 +468,9 @@ def ping(): @app.route("/test/init", methods=["POST"]) # type: ignore def init_handler(): - config = request.json.get("config") # type: ignore + if request.json is None: + return jsonify({"error": "No config provided"}), 400 + config = request.json.get("config") if config: init_st(json.loads(config)) return jsonify({"ok": True}) @@ -417,9 +504,9 @@ def mock_external_api(): return jsonify({"ok": True}) -# @app.route("/create", methods=["POST"]) # type: ignore +# @app.route("/create", methods=["POST"]) # def create_session(): -# recipe_user_id = request.json.get("recipeUserId") # type: ignore +# recipe_user_id = request.json.get("recipeUserId") # session = session.create_new_session(request, "public", recipe_user_id) # return jsonify({"status": "OK"}) @@ -434,7 +521,7 @@ def get_session(): ) -# @app.route("/refreshsession", methods=["POST"]) # type: ignore +# @app.route("/refreshsession", methods=["POST"]) # def refresh_session(): # session: SessionContainer = session.refresh_session(request) # return jsonify( @@ -449,7 +536,7 @@ def verify_session_route(): @app.errorhandler(404) -def not_found(error): # type: ignore +def not_found(error: Any) -> Any: return jsonify({"error": f"Route not found: {request.method} {request.path}"}), 404 diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 2efe4e177..83b4be17f 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -1,7 +1,12 @@ -from typing import Callable, List +from typing import Callable, List, Union from typing import Dict, Any, Optional -from supertokens_python.recipe.accountlinking import RecipeLevelUser -from supertokens_python.types import RecipeUserId +from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.recipe.accountlinking import ( + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) +from supertokens_python.types import AccountInfo, RecipeUserId from supertokens_python.types import APIResponse, User @@ -9,7 +14,7 @@ class Info: core_call_count = 0 -def get_func(eval_str: str) -> Callable: # type: ignore +def get_func(eval_str: str) -> Callable[..., Any]: if eval_str.startswith("supertokens.init.supertokens.networkInterceptor"): def func(*args): # type: ignore @@ -18,6 +23,107 @@ def func(*args): # type: ignore return func # type: ignore + elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): + + async def func( + i: Any, l: Any, o: Any, u: Any, a: Any + ) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: + if ( + "()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!1})" + in eval_str + ): + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,u,a)=>a.DO_LINK?{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}:{shouldAutomaticallyLink:!1}" + in eval_str + ): + if a.get("DO_LINK"): + return ShouldAutomaticallyLink(should_require_verification=True) + return ShouldNotAutomaticallyLink() + + if ( + "(i,l,o,u,a)=>a.DO_NOT_LINK?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,u,a)=>a.DO_NOT_LINK?{shouldAutomaticallyLink:!1}:a.DO_LINK_WITHOUT_VERIFICATION?{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if a.get("DO_LINK_WITHOUT_VERIFICATION"): + return ShouldAutomaticallyLink(should_require_verification=False) + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + '(i,l,o,a,e)=>e.DO_NOT_LINK||"test2@example.com"===i.email&&void 0===l?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}' + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if i.get("email") == "test2@example.com" and l is None: + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,d,t)=>t.DO_NOT_LINK||void 0!==l&&l.id===o.getUserId()?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if l is not None and l.get("id") == o.getUserId(): + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,d,t)=>t.DO_NOT_LINK||void 0!==l&&l.id===o.getUserId()?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if l is not None and l.get("id") == o.getUserId(): + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + '(i,l,o,a,e)=>e.DO_NOT_LINK||"test2@example.com"===i.email&&void 0===l?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}' + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if i.get("email") == "test2@example.com" and l is None: + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + 'async(i,e)=>{if("emailpassword"===i.recipeId){if(!((await supertokens.listUsersByAccountInfo("public",{email:i.email})).length>1))return{shouldAutomaticallyLink:!1}}return{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}}' + in eval_str + ): + if i.get("recipeId") == "emailpassword": + users = await list_users_by_account_info( + "public", AccountInfo(email=i.get("email")) + ) + if len(users) <= 1: + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + "async()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!0})" + in eval_str + or "()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!0})" + in eval_str + ): + return ShouldAutomaticallyLink(should_require_verification=True) + + return ShouldNotAutomaticallyLink() + + return func + raise Exception("Unknown eval string") From 364ccaa4d0ac2f8d07bf6f21475ace4dc1cdee4e Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 20:27:49 +0530 Subject: [PATCH 091/126] gets more tests to pass --- tests/test-server/__init__.py | 0 tests/test-server/app.py | 48 ++++++++++++---- tests/test-server/emailpassword.py | 55 +++++++++++++----- tests/test-server/emailverification.py | 67 ++++++++++++++++++++++ tests/test-server/session.py | 9 ++- tests/test-server/test_functions_mapper.py | 2 +- tests/test-server/utils.py | 30 ++++++++++ 7 files changed, 182 insertions(+), 29 deletions(-) create mode 100644 tests/test-server/__init__.py create mode 100644 tests/test-server/emailverification.py diff --git a/tests/test-server/__init__.py b/tests/test-server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test-server/app.py b/tests/test-server/app.py index c6690d726..87a63489b 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -2,11 +2,11 @@ from flask import Flask, request, jsonify from supertokens_python.framework import BaseRequest from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig -from supertokens_python.recipe import accountlinking +from supertokens_python.recipe import accountlinking, multifactorauth from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe -from utils import init_test_claims +from utils import init_test_claims # pylint: disable=import-error from supertokens_python.process_state import ProcessState from supertokens_python.recipe.dashboard.recipe import DashboardRecipe from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe @@ -18,10 +18,17 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -from test_functions_mapper import get_func, get_override_params, reset_override_params -from emailpassword import add_emailpassword_routes -from multitenancy import add_multitenancy_routes -from session import add_session_routes +from test_functions_mapper import ( # pylint: disable=import-error + get_func, + get_override_params, + reset_override_params, +) # pylint: disable=import-error +from emailpassword import add_emailpassword_routes # pylint: disable=import-error +from multitenancy import add_multitenancy_routes # pylint: disable=import-error +from emailverification import ( + add_emailverification_routes, +) # pylint: disable=import-error +from session import add_session_routes # pylint: disable=import-error from supertokens_python import ( AppInfo, Supertokens, @@ -52,8 +59,11 @@ def default_st_init(): - def origin_func( - request: Optional[BaseRequest] = None, context: Dict[str, Any] = {} + def origin_func( # pylint: disable=unused-argument, dangerous-default-value + request: Optional[BaseRequest] = None, + context: Dict[ # pylint: disable=unused-argument, dangerous-default-value + str, Any + ] = {}, # pylint: disable=unused-argument, dangerous-default-value ) -> str: if request is None: return "http://localhost:8080" @@ -150,7 +160,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: impl = get_func(override_name) else: - async def default_func(*args: Any, **kwargs: Any) -> Any: + async def default_func( # pylint: disable=unused-argument + *args: Any, **kwargs: Any # pylint: disable=unused-argument + ) -> Any: # pylint: disable=unused-argument return default_value impl = default_func @@ -388,6 +400,21 @@ def init_st(config: Dict[str, Any]): functions=override_functions ) recipe_list.append(emailverification.init(**ev_config)) + elif recipe_id == "multifactorauth": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + multifactorauth.init( + first_factors=recipe_config_json.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_builder_with_logging( + "MultifactorAuth.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), + ) + ) interceptor_func = None if config.get("supertokens", {}).get("networkInterceptor") is not None: @@ -536,13 +563,14 @@ def verify_session_route(): @app.errorhandler(404) -def not_found(error: Any) -> Any: +def not_found(error: Any) -> Any: # pylint: disable=unused-argument return jsonify({"error": f"Route not found: {request.method} {request.path}"}), 404 add_emailpassword_routes(app) add_multitenancy_routes(app) add_session_routes(app) +add_emailverification_routes(app) init_test_claims() diff --git a/tests/test-server/emailpassword.py b/tests/test-server/emailpassword.py index 3b6f73b4b..9691fa21c 100644 --- a/tests/test-server/emailpassword.py +++ b/tests/test-server/emailpassword.py @@ -6,8 +6,14 @@ UnknownUserIdError, UpdateEmailOrPasswordEmailChangeNotAllowedError, UpdateEmailOrPasswordOkResult, + WrongCredentialsError, ) import supertokens_python.recipe.emailpassword.syncio as emailpassword +from session import convert_session_to_container # pylint: disable=import-error +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error def add_emailpassword_routes(app: Flask): @@ -20,23 +26,35 @@ def emailpassword_signup(): # type: ignore email = data["email"] password = data["password"] user_context = data.get("userContext") + session = ( + convert_session_to_container(data["session"]) if "session" in data else None + ) - response = emailpassword.sign_up(tenant_id, email, password, user_context) + response = emailpassword.sign_up( + tenant_id, email, password, session, user_context + ) if isinstance(response, SignUpOkResult): return jsonify( { "status": "OK", - "user": { - "id": response.user.id, - "email": response.user.emails[0], - "timeJoined": response.user.time_joined, - "tenantIds": response.user.tenant_ids, - }, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), } ) - else: + elif isinstance(response, EmailAlreadyExistsError): return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) @app.route("/test/emailpassword/signin", methods=["POST"]) # type: ignore def emailpassword_signin(): # type: ignore @@ -55,16 +73,23 @@ def emailpassword_signin(): # type: ignore return jsonify( { "status": "OK", - "user": { - "id": response.user.id, - "email": response.user.emails[0], - "timeJoined": response.user.time_joined, - "tenantIds": response.user.tenant_ids, - }, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), } ) - else: + elif isinstance(response, WrongCredentialsError): return jsonify({"status": "WRONG_CREDENTIALS_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) @app.route("/test/emailpassword/createresetpasswordlink", methods=["POST"]) # type: ignore def emailpassword_create_reset_password_link(): # type: ignore diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py new file mode 100644 index 000000000..9707ec247 --- /dev/null +++ b/tests/test-server/emailverification.py @@ -0,0 +1,67 @@ +from flask import Flask, request, jsonify + +from supertokens_python.recipe.emailverification.interfaces import ( + CreateEmailVerificationTokenOkResult, + VerifyEmailUsingTokenOkResult, +) +from supertokens_python.recipe.emailverification.syncio import ( + create_email_verification_token, +) + + +def add_emailverification_routes(app: Flask): + @app.route("/test/emailverification/createemailverificationtoken", methods=["POST"]) # type: ignore + def f(): # type: ignore + from supertokens_python import convert_to_recipe_user_id + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) + tenant_id = data.get("tenantId", "public") + email = None if "email" not in data else data["email"] + user_context = data.get("userContext") + + response = create_email_verification_token( + tenant_id, recipe_user_id, email, user_context + ) + + if isinstance(response, CreateEmailVerificationTokenOkResult): + return jsonify({"status": "OK", "token": response.token}) + else: + return jsonify({"status": "EMAIL_ALREADY_VERIFIED_ERROR"}) + + @app.route("/test/emailverification/verifyemailusingtoken", methods=["POST"]) # type: ignore + def f2(): # type: ignore + from supertokens_python.recipe.emailverification.syncio import ( + verify_email_using_token, + ) + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + tenant_id = data.get("tenantId", "public") + token = data["token"] + attempt_account_linking = data.get("attemptAccountLinking", False) + user_context = data.get("userContext", {}) + + response = verify_email_using_token( + tenant_id, token, attempt_account_linking, user_context + ) + + if isinstance(response, VerifyEmailUsingTokenOkResult): + return jsonify( + { + "status": "OK", + "user": { + "email": response.user.email, + "recipeUserId": { + "recipeUserId": response.user.recipe_user_id.get_as_string() + }, + }, + } + ) + else: + return jsonify({"status": "EMAIL_VERIFICATION_INVALID_TOKEN_ERROR"}) diff --git a/tests/test-server/session.py b/tests/test-server/session.py index f97492218..be43bcb46 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -5,7 +5,10 @@ parse_jwt_without_signature_verification, ) from supertokens_python.types import RecipeUserId -from utils import deserialize_validator, get_max_version +from utils import ( # pylint: disable=import-error + deserialize_validator, + get_max_version, +) from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session @@ -175,12 +178,12 @@ def convert_session_to_container(data: Any) -> Session: jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) jwt_payload = jwt_info.payload - user_id = jwt_info.version == 2 and jwt_payload["userId"] or jwt_payload["sub"] + user_id = jwt_payload["userId"] if jwt_info.version == 2 else jwt_payload["sub"] session_handle = jwt_payload["sessionHandle"] recipe_user_id = RecipeUserId(jwt_payload.get("rsub", user_id)) anti_csrf_token = jwt_payload.get("antiCsrfToken") - tenant_id = jwt_info.version >= 4 and jwt_payload["tId"] or "public" + tenant_id = jwt_payload["tId"] if jwt_info.version >= 4 else "public" return Session( recipe_implementation=SessionRecipe.get_instance().recipe_implementation, diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 83b4be17f..188e97285 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -26,7 +26,7 @@ def func(*args): # type: ignore elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): async def func( - i: Any, l: Any, o: Any, u: Any, a: Any + i: Any, l: Any, o: Any, u: Any, a: Any # pylint: disable=unused-argument ) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: if ( "()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!1})" diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index 073d006ff..1791d24f2 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -2,6 +2,7 @@ from supertokens_python.recipe.session.claims import SessionClaim from supertokens_python.recipe.session.interfaces import SessionClaimValidator +from supertokens_python.types import RecipeUserId, User test_claims: Dict[str, SessionClaim] = {} # type: ignore @@ -56,3 +57,32 @@ def get_max_version(v1: str, v2: str) -> str: return v1 return v2 + + +def serialize_user(user: User, fdi_version: str) -> Dict[str, Any]: + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + return { + "user": { + "id": user.id, + "email": user.emails[0], + "timeJoined": user.time_joined, + "tenantIds": user.tenant_ids, + } + } + else: + return {"user": user.to_json()} + + +def serialize_recipe_user_id( + recipe_user_id: RecipeUserId, fdi_version: str +) -> Dict[str, Any]: + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + return {} + else: + return {"recipeUserId": recipe_user_id.get_as_string()} From 0797226b7b55d680f17dd6758103211e0b151b3c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Wed, 9 Oct 2024 23:44:36 +0530 Subject: [PATCH 092/126] fixes error logging --- .../recipe/accountlinking/interfaces.py | 8 +- .../accountlinking/recipe_implementation.py | 1 + .../recipe/accountlinking/syncio/__init__.py | 7 +- supertokens_python/recipe/session/utils.py | 2 +- tests/test-server/accountlinking.py | 284 ++++++++++++++++++ tests/test-server/app.py | 21 +- tests/test-server/thirdparty.py | 70 +++++ 7 files changed, 383 insertions(+), 10 deletions(-) create mode 100644 tests/test-server/accountlinking.py create mode 100644 tests/test-server/thirdparty.py diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index 8b28e4915..7ee318f95 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -220,9 +220,9 @@ def __init__(self, accounts_already_linked: bool, user: User): class LinkAccountsRecipeUserIdAlreadyLinkedError: def __init__( self, - primary_user_id: Optional[str] = None, - user: Optional[User] = None, - description: Optional[str] = None, + primary_user_id: str, + user: User, + description: str, ): self.status: Literal[ "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" @@ -236,14 +236,12 @@ class LinkAccountsAccountInfoAlreadyAssociatedError: def __init__( self, primary_user_id: Optional[str] = None, - user: Optional[User] = None, description: Optional[str] = None, ): self.status: Literal[ "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" self.primary_user_id = primary_user_id - self.user = user self.description = description diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index c9125b2fb..ef156e398 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -278,6 +278,7 @@ async def link_accounts( ): return LinkAccountsRecipeUserIdAlreadyLinkedError( primary_user_id=response["primaryUserId"], + user=response["user"], description=response["description"], ) elif ( diff --git a/supertokens_python/recipe/accountlinking/syncio/__init__.py b/supertokens_python/recipe/accountlinking/syncio/__init__.py index 9153a612a..6de893c1f 100644 --- a/supertokens_python/recipe/accountlinking/syncio/__init__.py +++ b/supertokens_python/recipe/accountlinking/syncio/__init__.py @@ -15,7 +15,8 @@ from supertokens_python.async_to_sync_wrapper import sync -from ..types import AccountInfoWithRecipeId, User, RecipeUserId +from ..types import AccountInfoWithRecipeId +from supertokens_python.types import RecipeUserId from supertokens_python.recipe.session import SessionContainer @@ -24,7 +25,7 @@ def create_primary_user_id_or_link_accounts( recipe_user_id: RecipeUserId, session: Optional[SessionContainer] = None, user_context: Optional[Dict[str, Any]] = None, -) -> User: +): from ..asyncio import ( create_primary_user_id_or_link_accounts as async_create_primary_user_id_or_link_accounts, ) @@ -40,7 +41,7 @@ def get_primary_user_that_can_be_linked_to_recipe_user_id( tenant_id: str, recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, -) -> Optional[User]: +): from ..asyncio import ( get_primary_user_that_can_be_linked_to_recipe_user_id as async_get_primary_user_that_can_be_linked_to_recipe_user_id, ) diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 7c5f1efd5..13f3d8dca 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -567,7 +567,7 @@ def anti_csrf_function( ( overwrite_session_during_sign_in_up if overwrite_session_during_sign_in_up is not None - else True + else False ), ) diff --git a/tests/test-server/accountlinking.py b/tests/test-server/accountlinking.py new file mode 100644 index 000000000..8a17e40b0 --- /dev/null +++ b/tests/test-server/accountlinking.py @@ -0,0 +1,284 @@ +from flask import Flask, request, jsonify +from supertokens_python import async_to_sync_wrapper, convert_to_recipe_user_id +from supertokens_python.recipe.accountlinking.syncio import can_create_primary_user +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.syncio import is_sign_in_allowed +from supertokens_python.recipe.accountlinking.syncio import is_sign_up_allowed +from supertokens_python.recipe.accountlinking.syncio import ( + get_primary_user_that_can_be_linked_to_recipe_user_id, +) +from supertokens_python.recipe.accountlinking.syncio import ( + create_primary_user_id_or_link_accounts, +) +from supertokens_python.recipe.accountlinking.syncio import unlink_account +from supertokens_python.recipe.accountlinking.syncio import is_email_change_allowed +from supertokens_python.recipe.accountlinking.syncio import ( + link_accounts, + create_primary_user, +) +from supertokens_python.recipe.accountlinking.interfaces import ( + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, +) +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import User +from utils import serialize_user # pylint: disable=import-error + + +def add_accountlinking_routes(app: Flask): + @app.route("/test/accountlinking/createprimaryuser", methods=["POST"]) # type: ignore + def create_primary_user_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = create_primary_user( + recipe_user_id, request.json.get("userContext") + ) + if isinstance(response, CreatePrimaryUserOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, + } + ) + elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/linkaccounts", methods=["POST"]) # type: ignore + def link_accounts_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = link_accounts( + recipe_user_id, + request.json["primaryUserId"], + request.json.get("userContext"), + ) + if isinstance(response, LinkAccountsOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + "accountsAlreadyLinked": response.accounts_already_linked, + } + ) + elif isinstance(response, LinkAccountsRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, LinkAccountsAccountInfoAlreadyAssociatedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "status": response.status, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/isemailchangeallowed", methods=["POST"]) # type: ignore + def is_email_change_allowed_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = is_email_change_allowed( + recipe_user_id, + request.json["newEmail"], + request.json["isVerified"], + request.json["session"], + request.json.get("userContext"), + ) + return jsonify(response) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/unlinkaccount", methods=["POST"]) # type: ignore + def unlink_account_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = unlink_account( + recipe_user_id, + request.json.get("userContext"), + ) + return jsonify( + { + "status": response.status, + "wasRecipeUserDeleted": response.was_recipe_user_deleted, + "wasLinked": response.was_linked, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/createprimaryuseridorlinkaccounts", methods=["POST"]) # type: ignore + def create_primary_user_id_or_link_accounts_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = create_primary_user_id_or_link_accounts( + request.json["tenantId"], + recipe_user_id, + request.json.get("session", None), + request.json.get("userContext", None), + ) + return jsonify(response.to_json()) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/getprimaryuserthatcanbelinkedtorecipeuserid", methods=["POST"]) # type: ignore + def get_primary_user_that_can_be_linked_to_recipe_user_id_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = get_primary_user_that_can_be_linked_to_recipe_user_id( + request.json["tenantId"], + recipe_user_id, + request.json.get("userContext", None), + ) + return jsonify(response.to_json() if response else None) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/issignupallowed", methods=["POST"]) # type: ignore + def is_signup_allowed_api(): # type: ignore + try: + assert request.json is not None + response = is_sign_up_allowed( + request.json["tenantId"], + AccountInfoWithRecipeId( + recipe_id=request.json["newUser"]["recipeId"], + email=request.json["newUser"]["email"], + phone_number=request.json["newUser"]["phoneNumber"], + third_party=ThirdPartyInfo( + third_party_user_id=request.json["newUser"]["thirdParty"]["id"], + third_party_id=request.json["newUser"]["thirdParty"][ + "thirdPartyId" + ], + ), + ), + request.json["isVerified"], + request.json.get("session", None), + request.json.get("userContext", None), + ) + return jsonify(response) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/issigninallowed", methods=["POST"]) # type: ignore + def is_signin_allowed_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = is_sign_in_allowed( + request.json["tenantId"], + recipe_user_id, + request.json.get("session", None), + request.json.get("userContext", None), + ) + return jsonify(response) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/verifyemailforrecipeuseriflinkedaccountsareverified", methods=["POST"]) # type: ignore + def verify_email_for_recipe_user_if_linked_accounts_are_verified_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + user = User.from_json(request.json["user"]) + async_to_sync_wrapper.sync( + AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=request.json.get("userContext"), + ) + ) + return jsonify({}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/cancreateprimaryuser", methods=["POST"]) # type: ignore + def can_create_primary_user_api(): # type: ignore + try: + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = can_create_primary_user( + recipe_user_id, request.json.get("userContext") + ) + if isinstance(response, CanCreatePrimaryUserOkResult): + return jsonify( + { + "status": response.status, + "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, + } + ) + elif isinstance( + response, CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError + ): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "status": response.status, + "primaryUserId": response.primary_user_id, + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 87a63489b..713debd98 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify from supertokens_python.framework import BaseRequest @@ -24,7 +25,9 @@ reset_override_params, ) # pylint: disable=import-error from emailpassword import add_emailpassword_routes # pylint: disable=import-error +from thirdparty import add_thirdparty_routes # pylint: disable=import-error from multitenancy import add_multitenancy_routes # pylint: disable=import-error +from accountlinking import add_accountlinking_routes # pylint: disable=import-error from emailverification import ( add_emailverification_routes, ) # pylint: disable=import-error @@ -105,7 +108,10 @@ async def finalFunction(*args: Any, **kwargs: Any): {"args": args, "kwargs": kwargs}, ) try: - res = await originalFunction(*args, **kwargs) + if inspect.iscoroutinefunction(originalFunction): + res = await originalFunction(*args, **kwargs) + else: + res = originalFunction(*args, **kwargs) override_logging.log_override_event( name + "." + toCamelCase(functionName), "RES", res ) @@ -256,6 +262,9 @@ def init_st(config: Dict[str, Any]): use_dynamic_access_token_signing_key=recipe_config_json.get( "useDynamicAccessTokenSigningKey" ), + overwrite_session_during_sign_in_up=recipe_config_json.get( + "overwriteSessionDuringSignInUp", None + ), override=session.InputOverrideConfig( apis=override_builder_with_logging( "Session.override.apis", @@ -440,6 +449,14 @@ def inner( body: Optional[Dict[str, Any]], user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: + # print( + # "-------------------------------------------!!!!!!!!!!!!!!!!!!!!!!!!!!!" + # ) + # print(url) + # import traceback + # print("Stack trace:") + # traceback.print_stack() + if interceptor_func is not None: resp = interceptor_func( url, method, headers, params, body, user_context @@ -571,6 +588,8 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument add_multitenancy_routes(app) add_session_routes(app) add_emailverification_routes(app) +add_thirdparty_routes(app) +add_accountlinking_routes(app) init_test_claims() diff --git a/tests/test-server/thirdparty.py b/tests/test-server/thirdparty.py new file mode 100644 index 000000000..dfc0ff009 --- /dev/null +++ b/tests/test-server/thirdparty.py @@ -0,0 +1,70 @@ +from flask import Flask, request, jsonify + +from session import convert_session_to_container # pylint: disable=import-error +from supertokens_python.recipe.thirdparty.interfaces import ( + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) +from supertokens_python.recipe.thirdparty.syncio import manually_create_or_update_user +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error + + +def add_thirdparty_routes(app: Flask): + @app.route("/test/thirdparty/manuallycreateorupdateuser", methods=["POST"]) # type: ignore + def thirdpartymanuallycreateorupdate(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + tenant_id = data.get("tenantId", "public") + third_party_id = data["thirdPartyId"] + third_party_user_id = data["thirdPartyUserId"] + email = data["email"] + is_verified = data["isVerified"] + user_context = data.get("userContext", {}) + + session = None + if data.get("session"): + session = convert_session_to_container(data["session"]) + + response = manually_create_or_update_user( + tenant_id, + third_party_id, + third_party_user_id, + email, + is_verified, + session, + user_context, + ) + + if isinstance(response, ManuallyCreateOrUpdateUserOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, EmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": response.reason} + ) + elif isinstance(response, SignInUpNotAllowed): + return jsonify(response.to_json()) + elif isinstance(response, SignInUpNotAllowed): + return jsonify(response.to_json()) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) From c1f197511940444b3f4123e66709c1db323f916f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 12:27:57 +0530 Subject: [PATCH 093/126] fixes stuff --- .../recipe/passwordless/interfaces.py | 8 + .../recipe/session/asyncio/__init__.py | 4 +- .../recipe/session/interfaces.py | 5 +- .../recipe/session/recipe_implementation.py | 3 +- .../recipe/session/session_class.py | 1 + .../claims/test_primitive_array_claim.py | 57 +++--- tests/sessions/claims/test_primitive_claim.py | 45 +++-- tests/sessions/claims/utils.py | 10 +- tests/test-server/app.py | 5 +- tests/test-server/passwordless.py | 163 ++++++++++++++++++ tests/test-server/session.py | 39 +++++ tests/test-server/utils.py | 89 ++++++++-- 12 files changed, 351 insertions(+), 78 deletions(-) create mode 100644 tests/test-server/passwordless.py diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 4a92ae6ee..877539069 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -111,6 +111,14 @@ def from_json(json: Dict[str, Any]) -> ConsumedDevice: phone_number=json["phoneNumber"] if "phoneNumber" in json else None, ) + def to_json(self) -> Dict[str, Any]: + return { + "preAuthSessionId": self.pre_auth_session_id, + "failedCodeInputAttemptCount": self.failed_code_input_attempt_count, + "email": self.email, + "phoneNumber": self.phone_number, + } + class ConsumeCodeOkResult: def __init__( diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index 346d111ef..da67f9699 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -123,7 +123,9 @@ async def create_new_session_without_request_response( user_id = user.id for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, recipe_user_id, tenant_id, user_context) + update = await claim.build( + user_id, recipe_user_id, tenant_id, final_access_token_payload, user_context + ) final_access_token_payload = {**final_access_token_payload, **update} return await SessionRecipe.get_instance().recipe_implementation.create_new_session( diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 389455841..4075f8489 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -677,11 +677,8 @@ async def build( recipe_user_id: RecipeUserId, tenant_id: str, current_payload: Dict[str, Any], - user_context: Optional[Dict[str, Any]] = None, + user_context: Dict[str, Any], ) -> JSONObject: - if user_context is None: - user_context = {} - value = await resolve( self.fetch_value( user_id, recipe_user_id, tenant_id, current_payload, user_context diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index 9a5fd447d..363259eff 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -139,7 +139,7 @@ async def validate_claims( log_debug_message( "update_claims_in_payload_if_needed %s refetch result %s", validator.id, - json.dumps(value), + value, ) if value is not None: access_token_payload = validator.claim.add_to_payload_( @@ -425,6 +425,7 @@ async def fetch_and_set_claim( session_info.user_id, session_info.recipe_user_id, session_info.tenant_id, + session_info.custom_claims_in_access_token_payload, user_context, ) return await self.merge_into_access_token_payload( diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index cf4ea2b45..6150fb0e7 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -240,6 +240,7 @@ async def fetch_and_set_claim( self.get_user_id(user_context=user_context), self.get_recipe_user_id(user_context=user_context), self.get_tenant_id(user_context=user_context), + self.get_access_token_payload(user_context=user_context), user_context, ) return await self.merge_into_access_token_payload(update, user_context) diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index 35c930d93..ef2b8ef78 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -1,5 +1,5 @@ import math -from typing import List, Tuple +from typing import Any, Dict, List, Tuple from unittest.mock import MagicMock from pytest import fixture, mark @@ -58,30 +58,31 @@ def patch_get_timestamp_ms(pac_time_patch: Tuple[MockerFixture, int]): async def test_primitive_claim(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveArrayClaim("key", async_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - user_id, ctx = "user_id", {} + user_id = "user_id" + ctx: Dict[str, Any] = {} recipe_user_id = RecipeUserId(user_id) - await claim.build(user_id, recipe_user_id, DEFAULT_TENANT_ID, ctx) + await claim.build(user_id, recipe_user_id, DEFAULT_TENANT_ID, {}, ctx) assert sync_fetch_value.call_count == 1 assert ( user_id, @@ -99,8 +100,8 @@ async def test_primitive_claim_fetch_value_none(): fetch_value_none.return_value = None claim = PrimitiveArrayClaim("key", fetch_value_none) - user_id, ctx = "user_id", {} - res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert res == {} @@ -129,7 +130,7 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert claim.get_last_refetch_time(payload) == timestamp @@ -153,7 +154,7 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes(excluded_item).validate(payload, {}) @@ -168,7 +169,7 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes(included_item).validate(payload, {}) @@ -178,7 +179,7 @@ async def test_validator_should_validate_matching_payload(): async def test_should_not_validate_old_values(patch_get_timestamp_ms: MagicMock): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -198,7 +199,7 @@ async def test_should_validate_old_values_if_max_age_is_none_and_default_is_inf( ): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -218,7 +219,7 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert ( await resolve( @@ -233,7 +234,7 @@ async def test_validator_should_refetch_if_value_is_old( ): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -252,7 +253,7 @@ async def test_validator_should_not_refetch_if_max_age_is_none_and_default_is_in ): claim = claim_with_inf_max_age payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -271,7 +272,7 @@ async def test_validator_should_validate_values_with_default_max_age( ): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 MINS: @@ -286,7 +287,7 @@ async def test_validator_should_not_refetch_if_max_age_overrides_to_inf( ): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 1 week @@ -321,7 +322,7 @@ async def test_validator_excludes_should_not_validate_empty_payload(): async def test_validator_excludes_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes(included_item).validate(payload, {}) @@ -336,7 +337,7 @@ async def test_validator_excludes_should_not_validate_mismatching_payload(): async def test_validator_excludes_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes(excluded_item).validate(payload, {}) @@ -361,7 +362,7 @@ async def test_validator_includes_all_should_not_validate_empty_payload(): async def test_validator_includes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes_all(excluded_item).validate(payload, {}) @@ -376,7 +377,7 @@ async def test_validator_includes_all_should_not_validate_mismatching_payload(): async def test_validator_includes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.includes_all(included_item).validate(payload, {}) @@ -401,7 +402,7 @@ async def test_validator_excludes_all_should_not_validate_empty_payload(): async def test_validator_excludes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes_all(included_item).validate(payload, {}) @@ -416,7 +417,7 @@ async def test_validator_excludes_all_should_not_validate_mismatching_payload(): async def test_validator_excludes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.excludes_all(excluded_item).validate(payload, {}) @@ -428,7 +429,7 @@ async def test_validator_should_not_validate_older_values_with_5min_default_max_ ): claim = PrimitiveArrayClaim("key", sync_fetch_value, 300) # 5 mins payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 MINS: diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 3fde9dd38..bbfa2d815 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -25,35 +25,32 @@ def teardown_function(_): async def test_primitive_claim(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveClaim("key", async_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", RecipeUserId("user_id"), "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveClaim("key", sync_fetch_value) - user_id, ctx = "user_id", {} - await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert sync_fetch_value.call_count == 1 assert ( user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, - ctx, + {}, {}, ) == sync_fetch_value.call_args_list[0][ 0 @@ -65,8 +62,8 @@ async def test_primitive_claim_fetch_value_none(): fetch_value_none.return_value = None claim = PrimitiveClaim("key", fetch_value_none) - user_id, ctx = "user_id", {} - res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert res == {} @@ -97,7 +94,7 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert claim.get_last_refetch_time(payload) == timestamp @@ -121,7 +118,7 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val2).validate(payload, {}) @@ -136,7 +133,7 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val).validate(payload, {}) @@ -146,7 +143,7 @@ async def test_validator_should_validate_matching_payload(): async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 mins: @@ -166,7 +163,7 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) assert ( await resolve(claim.validators.has_value(val2).should_refetch(payload, {})) @@ -192,7 +189,7 @@ async def test_should_not_validate_empty_payload(): async def test_has_fresh_value_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val2, 600).validate(payload, {}) assert res.is_valid is False @@ -206,7 +203,7 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload(): async def test_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) res = await claim.validators.has_value(val, 600).validate(payload, {}) assert res.is_valid is True @@ -218,7 +215,7 @@ async def test_should_not_validate_old_values_as_well( claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 mins: @@ -236,14 +233,16 @@ async def test_should_refetch_if_value_is_not_set(): async def test_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", RecipeUserId("userId"), "public", {}) + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}, {}) assert claim.validators.has_value(val2, 600).should_refetch(payload, {}) is False async def test_should_refetch_if_value_is_old(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", RecipeUserId("userId"), DEFAULT_TENANT_ID, {}) + payload = await claim.build( + "userId", RecipeUserId("userId"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -256,7 +255,7 @@ async def test_should_not_validate_old_values_as_well_with_default_max_age_provi ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins payload = await claim.build( - "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {} + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} ) # Increase clock time by 10 mins: @@ -275,7 +274,7 @@ async def test_should_refetch_if_value_is_old_with_default_max_age_provided( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("userId", RecipeUserId("userId"), "public", {}) + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}, {}) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 92a78c96d..b66fd6d68 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -31,11 +31,15 @@ async def new_create_new_session( tenant_id: str, user_context: Dict[str, Any], ): - payload_update = await claim.build( - user_id, RecipeUserId(user_id), tenant_id, user_context - ) if access_token_payload is None: access_token_payload = {} + payload_update = await claim.build( + user_id, + RecipeUserId(user_id), + tenant_id, + access_token_payload, + user_context, + ) access_token_payload = { **access_token_payload, **payload_update, diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 713debd98..8a4a3a023 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -7,7 +7,7 @@ from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe -from utils import init_test_claims # pylint: disable=import-error +from passwordless import add_passwordless_routes # pylint: disable=import-error from supertokens_python.process_state import ProcessState from supertokens_python.recipe.dashboard.recipe import DashboardRecipe from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe @@ -590,8 +590,7 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument add_emailverification_routes(app) add_thirdparty_routes(app) add_accountlinking_routes(app) - -init_test_claims() +add_passwordless_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/passwordless.py b/tests/test-server/passwordless.py new file mode 100644 index 000000000..4edb510f5 --- /dev/null +++ b/tests/test-server/passwordless.py @@ -0,0 +1,163 @@ +from flask import Flask, request, jsonify +from supertokens_python import convert_to_recipe_user_id +from supertokens_python.recipe.passwordless.interfaces import ( + ConsumeCodeExpiredUserInputCodeError, + ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeOkResult, + ConsumeCodeRestartFlowError, + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, +) +from supertokens_python.recipe.passwordless.syncio import ( + signinup, + create_code, + update_user, + consume_code, +) +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error +from session import convert_session_to_container # pylint: disable=import-error + + +def add_passwordless_routes(app: Flask): + @app.route("/test/passwordless/signinup", methods=["POST"]) # type: ignore + def sign_in_up_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = signinup( + email=body.get("email", None), + phone_number=body.get("phoneNumber", None), + tenant_id=body.get("tenantId", "public"), + user_context=body.get("userContext"), + session=session, + ) + return jsonify( + { + "status": "OK", + "createdNewRecipeUser": response.created_new_recipe_user, + "consumedDevice": response.consumed_device.to_json(), + **serialize_user(response.user, request.headers.get("fdi-version", "")), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + + @app.route("/test/passwordless/createcode", methods=["POST"]) # type: ignore + def create_code_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = create_code( + email=body.get("email"), + phone_number=body.get("phoneNumber"), + tenant_id=body.get("tenantId", "public"), + user_input_code=body.get("userInputCode"), + user_context=body.get("userContext"), + session=session, + ) + return jsonify( + { + "codeId": response.code_id, + "preAuthSessionId": response.pre_auth_session_id, + "codeLifeTime": response.code_life_time, + "deviceId": response.device_id, + "linkCode": response.link_code, + "timeCreated": response.time_created, + "userInputCode": response.user_input_code, + } + ) + + @app.route("/test/passwordless/consumecode", methods=["POST"]) # type: ignore + def consume_code_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = consume_code( + device_id=body["deviceId"], + pre_auth_session_id=body["preAuthSessionId"], + user_input_code=body.get("userInputCode"), + link_code=body["linkCode"], + tenant_id=body.get("tenantId", "public"), + user_context=body.get("userContext"), + session=session, + ) + + if isinstance(response, ConsumeCodeOkResult): + return jsonify( + { + "status": "OK", + "createdNewRecipeUser": response.created_new_recipe_user, + "consumedDevice": response.consumed_device.to_json(), + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, ConsumeCodeIncorrectUserInputCodeError): + return jsonify( + { + "status": "INCORRECT_USER_INPUT_CODE_ERROR", + "failedCodeInputAttemptCount": response.failed_code_input_attempt_count, + "maximumCodeInputAttempts": response.maximum_code_input_attempts, + } + ) + elif isinstance(response, ConsumeCodeExpiredUserInputCodeError): + return jsonify( + { + "status": "EXPIRED_USER_INPUT_CODE_ERROR", + "failedCodeInputAttemptCount": response.failed_code_input_attempt_count, + "maximumCodeInputAttempts": response.maximum_code_input_attempts, + } + ) + elif isinstance(response, ConsumeCodeRestartFlowError): + return jsonify({"status": "RESTART_FLOW_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) + + @app.route("/test/passwordless/updateuser", methods=["POST"]) # type: ignore + def update_user_api(): # type: ignore + assert request.json is not None + body = request.json + response = update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + user_context=body.get("userContext"), + ) + + if isinstance(response, UpdateUserOkResult): + return jsonify({"status": "OK"}) + elif isinstance(response, UpdateUserUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(response, UpdateUserEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, UpdateUserPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, EmailChangeNotAllowedError): + return jsonify({"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR"}) + else: + return jsonify({"status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR"}) diff --git a/tests/test-server/session.py b/tests/test-server/session.py index be43bcb46..00c5dcda3 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,5 +1,6 @@ from typing import Any from flask import Flask, request, jsonify +from override_logging import log_override_event # pylint: disable=import-error from supertokens_python.recipe.session.interfaces import TokenInfo from supertokens_python.recipe.session.jwt import ( parse_jwt_without_signature_verification, @@ -12,6 +13,7 @@ from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session +from utils import deserialize_claim # pylint: disable=import-error def add_session_routes(app: Flask): @@ -173,6 +175,43 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore } ) + @app.route("/test/session/sessionobject/fetchandsetclaim", methods=["POST"]) # type: ignore + def fetch_and_set_claim_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.fetchandsetclaim", "CALL", data) + session = convert_session_to_container(data) + + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + session.sync_fetch_and_set_claim(claim, user_context) + response = { + "updatedSession": { + "sessionHandle": session.get_handle(), + "userId": session.get_user_id(), + "recipeUserId": session.get_recipe_user_id().get_as_string(), + "tenantId": session.get_tenant_id(), + "userDataInAccessToken": session.get_access_token_payload(), + "accessToken": session.get_access_token(), + "frontToken": session.get_all_session_tokens_dangerously()[ + "frontToken" + ], + "refreshToken": session.get_all_session_tokens_dangerously()[ + "refreshToken" + ], + "antiCsrfToken": session.get_all_session_tokens_dangerously()[ + "antiCsrfToken" + ], + "accessTokenUpdated": session.get_all_session_tokens_dangerously()[ + "accessAndFrontTokenUpdated" + ], + } + } + return jsonify(response) + def convert_session_to_container(data: Any) -> Session: jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index 1791d24f2..1d70e8035 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -1,26 +1,85 @@ -from typing import Any, Dict +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) from supertokens_python.recipe.session.claims import SessionClaim from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.types import RecipeUserId, User - -test_claims: Dict[str, SessionClaim] = {} # type: ignore - - -def init_test_claims(): - add_builtin_claims() - - -def add_builtin_claims(): - from supertokens_python.recipe.emailverification import EmailVerificationClaim - - test_claims[EmailVerificationClaim.key] = EmailVerificationClaim +from override_logging import log_override_event # pylint: disable=import-error +from supertokens_python.recipe.session.claims import BooleanClaim +from supertokens_python.recipe.emailverification import EmailVerificationClaim +from supertokens_python.recipe.userroles import UserRoleClaim +from supertokens_python.recipe.userroles import PermissionClaim +from typing import Any, Dict +from supertokens_python.recipe.session.claims import PrimitiveClaim + + +def mock_claim_builder(key: str, values: Any) -> PrimitiveClaim[Any]: + def fetch_value( + user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> Any: + log_override_event( + f"claim-{key}.fetchValue", + "CALL", + { + "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), + "tenantId": tenant_id, + "currentPayload": current_payload, + "userContext": user_context, + }, + ) + + ret_val: Any = user_context.get("st-stub-arr-value") or ( + values[0] + if isinstance(values, list) and isinstance(values[0], list) + else values + ) + log_override_event(f"claim-{key}.fetchValue", "RES", ret_val) + + return ret_val + + return PrimitiveClaim(key=key or "st-stub-primitive", fetch_value=fetch_value) + + +test_claim_setups: Dict[str, SessionClaim[Any]] = { + "st-true": BooleanClaim( + key="st-true", + fetch_value=lambda *_args, **_kwargs: True, # type: ignore + ), + "st-undef": BooleanClaim( + key="st-undef", + fetch_value=lambda *_args, **_kwargs: None, # type: ignore + ), +} + +# Add all built-in claims +for claim in [ + EmailVerificationClaim, + MultiFactorAuthClaim, + UserRoleClaim, + PermissionClaim, +]: + test_claim_setups[claim.key] = claim # type: ignore + + +def deserialize_claim(serialized_claim: Dict[str, Any]) -> SessionClaim[Any]: + key = serialized_claim["key"] + + if key.startswith("st-stub-"): + return mock_claim_builder(key.replace("st-stub-", "", 1), serialized_claim) + + return test_claim_setups[key] def deserialize_validator(validatorsInput: Any) -> SessionClaimValidator: # type: ignore key = validatorsInput["key"] - if key in test_claims: - claim = test_claims[key] # type: ignore + if key in test_claim_setups: + claim = test_claim_setups[key] validator_name = validatorsInput["validatorName"] if hasattr(claim.validators, toSnakeCase(validator_name)): # type: ignore validator_func = getattr(claim.validators, toSnakeCase(validator_name)) # type: ignore From 29a896951b9965cb123b3a2a604f1134a4a260af Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 12:45:53 +0530 Subject: [PATCH 094/126] adds more apis --- tests/test-server/app.py | 73 ++++++++++++++++++++++++++++++++++++++- tests/test-server/totp.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 tests/test-server/totp.py diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 8a4a3a023..2cf199fb4 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -3,7 +3,13 @@ from flask import Flask, request, jsonify from supertokens_python.framework import BaseRequest from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig -from supertokens_python.recipe import accountlinking, multifactorauth +from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig +from supertokens_python.recipe import ( + accountlinking, + multifactorauth, + passwordless, + totp, +) from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe from supertokens_python.recipe.totp.recipe import TOTPRecipe @@ -24,6 +30,7 @@ get_override_params, reset_override_params, ) # pylint: disable=import-error +from totp import add_totp_routes # pylint: disable=import-error from emailpassword import add_emailpassword_routes # pylint: disable=import-error from thirdparty import add_thirdparty_routes # pylint: disable=import-error from multitenancy import add_multitenancy_routes # pylint: disable=import-error @@ -424,6 +431,69 @@ def init_st(config: Dict[str, Any]): ), ) ) + elif recipe_id == "passwordless": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + contact_config: passwordless.ContactConfig = ( + passwordless.ContactEmailOnlyConfig() + ) + if recipe_config_json.get("contactMethod") == "PHONE": + contact_config = passwordless.ContactPhoneOnlyConfig() + elif recipe_config_json.get("contactMethod") == "EMAIL_OR_PHONE": + contact_config = passwordless.ContactEmailOrPhoneConfig() + + recipe_list.append( + passwordless.init( + email_delivery=EmailDeliveryConfig( + override=override_builder_with_logging( + "Passwordless.emailDelivery.override", + config.get("emailDelivery", {}).get("override", None), + ) + ), + sms_delivery=SMSDeliveryConfig( + override=override_builder_with_logging( + "Passwordless.smsDelivery.override", + config.get("smsDelivery", {}).get("override", None), + ) + ), + contact_config=contact_config, + flow_type=recipe_config_json.get("flowType"), + override=passwordless.InputOverrideConfig( + apis=override_builder_with_logging( + "Passwordless.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "Passwordless.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ), + ) + ) + elif recipe_id == "totp": + from supertokens_python.recipe.totp.types import ( + OverrideConfig as TOTPOverrideConfig, + ) + + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + totp.init( + config=totp.TOTPConfig( + default_period=recipe_config_json.get("defaultPeriod"), + default_skew=recipe_config_json.get("defaultSkew"), + issuer=recipe_config_json.get("issuer"), + override=TOTPOverrideConfig( + apis=override_builder_with_logging( + "Multitenancy.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "Multitenancy.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ), + ) + ) + ) interceptor_func = None if config.get("supertokens", {}).get("networkInterceptor") is not None: @@ -591,6 +661,7 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument add_thirdparty_routes(app) add_accountlinking_routes(app) add_passwordless_routes(app) +add_totp_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/totp.py b/tests/test-server/totp.py new file mode 100644 index 000000000..46e8af537 --- /dev/null +++ b/tests/test-server/totp.py @@ -0,0 +1,52 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.totp.syncio import create_device, verify_device +from supertokens_python.recipe.totp.types import ( + CreateDeviceOkResult, +) +from supertokens_python.recipe.totp.types import ( + DeviceAlreadyExistsError, + InvalidTOTPError, + UnknownDeviceError, + VerifyDeviceOkResult, +) + + +def add_totp_routes(app: Flask): + @app.route("/test/totp/createdevice", methods=["POST"]) # type: ignore + def create_device_api(): # type: ignore + assert request.json is not None + body = request.json + response = create_device( + user_id=body.get("userId"), + user_identifier_info=body.get("userIdentifierInfo"), + device_name=body.get("deviceName"), + skew=body.get("skew"), + period=body.get("period"), + user_context=body.get("userContext"), + ) + if isinstance(response, CreateDeviceOkResult): + return jsonify(response.to_json()) + elif isinstance(response, DeviceAlreadyExistsError): + return jsonify(response.to_json()) + else: + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + + @app.route("/test/totp/verifydevice", methods=["POST"]) # type: ignore + def verify_device_api(): # type: ignore + assert request.json is not None + body = request.json + response = verify_device( + tenant_id=body.get("tenantId"), + user_id=body.get("userId"), + device_name=body.get("deviceName"), + totp=body.get("totp"), + user_context=body.get("userContext"), + ) + if isinstance(response, VerifyDeviceOkResult): + return jsonify(response.to_json()) + elif isinstance(response, UnknownDeviceError): + return jsonify(response.to_json()) + elif isinstance(response, InvalidTOTPError): + return jsonify(response.to_json()) + else: + return jsonify(response.to_json()) From 81cf53dc35722c246f3e38636e6a7b01fb30cc09 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 19:25:20 +0530 Subject: [PATCH 095/126] fixes more tests --- tests/test-server/app.py | 10 +- tests/test-server/session.py | 62 ++++---- tests/test-server/supertokens.py | 89 +++++++++++ tests/test-server/test_functions_mapper.py | 164 ++++++++++++++++++++- 4 files changed, 293 insertions(+), 32 deletions(-) create mode 100644 tests/test-server/supertokens.py diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 2cf199fb4..bd7489c89 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -141,7 +141,6 @@ def builder(oI: T) -> T: for attr in dir(oI) if callable(getattr(oI, attr)) and not attr.startswith("__") ] - for member in members: create_override(oI, member, name, override_name) return oI @@ -340,7 +339,7 @@ def init_st(config: Dict[str, Any]): ), ) - include_in_non_public_tenants_by_default = None + include_in_non_public_tenants_by_default = False if "includeInNonPublicTenantsByDefault" in provider: include_in_non_public_tenants_by_default = provider[ @@ -391,6 +390,10 @@ def init_st(config: Dict[str, Any]): ), ), include_in_non_public_tenants_by_default=include_in_non_public_tenants_by_default, + override=override_builder_with_logging( + "ThirdParty.providers.override", + provider.get("override", None), + ), ) providers.append(provider_input) recipe_list.append( @@ -662,6 +665,9 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument add_accountlinking_routes(app) add_passwordless_routes(app) add_totp_routes(app) +from supertokens import add_supertokens_routes # pylint: disable=import-error + +add_supertokens_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 00c5dcda3..20bc00807 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,6 +1,7 @@ -from typing import Any +from typing import Any, Dict from flask import Flask, request, jsonify from override_logging import log_override_event # pylint: disable=import-error +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import TokenInfo from supertokens_python.recipe.session.jwt import ( parse_jwt_without_signature_verification, @@ -50,27 +51,7 @@ def create_new_session_without_request_response(): # type: ignore user_context, ) - return jsonify( - { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - } - ) + return jsonify(convert_session_to_json(session_container)) @app.route("/test/session/getsessionwithoutrequestresponse", methods=["POST"]) # type: ignore def get_session_without_request_response(): # type: ignore @@ -83,13 +64,14 @@ def get_session_without_request_response(): # type: ignore options = data.get("options") user_context = data.get("userContext", {}) - try: - session_container = session.get_session_without_request_response( - access_token, anti_csrf_token, options, user_context - ) - return jsonify(session_container) - except Exception as e: - return jsonify({"error": str(e)}), 500 + session_container = session.get_session_without_request_response( + access_token, anti_csrf_token, options, user_context + ) + return jsonify( + None + if session_container is None + else convert_session_to_json(session_container) + ) @app.route("/test/session/sessionobject/assertclaims", methods=["POST"]) # type: ignore def assert_claims(): # type: ignore @@ -213,6 +195,28 @@ def fetch_and_set_claim_api(): # type: ignore return jsonify(response) +def convert_session_to_json(session_container: SessionContainer) -> Dict[str, Any]: + return { + "sessionHandle": session_container.get_handle(), + "userId": session_container.get_user_id(), + "tenantId": session_container.get_tenant_id(), + "userDataInAccessToken": session_container.get_access_token_payload(), + "accessToken": session_container.get_access_token(), + "frontToken": session_container.get_all_session_tokens_dangerously()[ + "frontToken" + ], + "refreshToken": session_container.get_all_session_tokens_dangerously()[ + "refreshToken" + ], + "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ + "antiCsrfToken" + ], + "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ + "accessAndFrontTokenUpdated" + ], + } + + def convert_session_to_container(data: Any) -> Session: jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) jwt_payload = jwt_info.payload diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py new file mode 100644 index 000000000..3cc40d508 --- /dev/null +++ b/tests/test-server/supertokens.py @@ -0,0 +1,89 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import AccountInfo +from supertokens_python.syncio import ( + get_user, + delete_user, + list_users_by_account_info, + get_users_newest_first, + get_users_oldest_first, +) + + +def add_supertokens_routes(app: Flask): + @app.route("/test/supertokens/getuser", methods=["POST"]) # type: ignore + def get_user_api(): # type: ignore + assert request.json is not None + response = get_user(request.json["userId"], request.json.get("userContext")) + return jsonify(None if response is None else response.to_json()) + + @app.route("/test/supertokens/deleteuser", methods=["POST"]) # type: ignore + def delete_user_api(): # type: ignore + assert request.json is not None + delete_user( + request.json["userId"], + request.json["removeAllLinkedAccounts"], + request.json.get("userContext"), + ) + return jsonify({"status": "OK"}) + + @app.route("/test/supertokens/listusersbyaccountinfo", methods=["POST"]) # type: ignore + def list_users_by_account_info_api(): # type: ignore + assert request.json is not None + response = list_users_by_account_info( + request.json["tenantId"], + AccountInfo( + email=request.json["accountInfo"].get("email", None), + phone_number=request.json["accountInfo"].get("phoneNumber", None), + third_party=( + None + if "thirdParty" not in request.json["accountInfo"] + else ThirdPartyInfo( + third_party_id=request.json["accountInfo"]["thirdParty"][ + "thirdPartyId" + ], + third_party_user_id=request.json["accountInfo"]["thirdParty"][ + "id" + ], + ) + ), + ), + request.json["doUnionOfAccountInfo"], + request.json.get("userContext"), + ) + + return jsonify([r.to_json() for r in response]) + + @app.route("/test/supertokens/getusersnewestfirst", methods=["POST"]) # type: ignore + def get_users_newest_first_api(): # type: ignore + assert request.json is not None + response = get_users_newest_first( + include_recipe_ids=request.json["includeRecipeIds"], + limit=request.json["limit"], + pagination_token=request.json["paginationToken"], + tenant_id=request.json["tenantId"], + user_context=request.json.get("userContext"), + ) + return jsonify( + { + "nextPaginationToken": response.next_pagination_token, + "users": [r.to_json() for r in response.users], + } + ) + + @app.route("/test/supertokens/getusersoldestfirst", methods=["POST"]) # type: ignore + def get_users_oldest_first_api(): # type: ignore + assert request.json is not None + response = get_users_oldest_first( + include_recipe_ids=request.json["includeRecipeIds"], + limit=request.json["limit"], + pagination_token=request.json["paginationToken"], + tenant_id=request.json["tenantId"], + user_context=request.json.get("userContext"), + ) + return jsonify( + { + "nextPaginationToken": response.next_pagination_token, + "users": [r.to_json() for r in response.users], + } + ) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 188e97285..63f32d0f4 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -6,6 +6,11 @@ ShouldAutomaticallyLink, ShouldNotAutomaticallyLink, ) +from supertokens_python.recipe.thirdparty.types import ( + RawUserInfoFromProvider, + UserInfo, + UserInfoEmail, +) from supertokens_python.types import AccountInfo, RecipeUserId from supertokens_python.types import APIResponse, User @@ -124,7 +129,164 @@ async def func( return func - raise Exception("Unknown eval string") + if eval_str.startswith("thirdparty.init.signInAndUpFeature.providers"): + + def custom_provider(provider: Any): + if "custom-ev" in eval_str: + + def exchange_auth_code_for_oauth_tokens1( + redirect_uri_info: Any, # pylint: disable=unused-argument + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return {} + + def get_user_info1( + oauth_tokens: Any, + user_context: Any, # pylint: disable=unused-argument + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "user"), + email=UserInfoEmail( + email=oauth_tokens.get("email", "email@test.com"), + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens1 + ) + provider.get_user_info = get_user_info1 + return provider + + if "custom-no-ev" in eval_str: + + def exchange_auth_code_for_oauth_tokens2( + redirect_uri_info: Any, # pylint: disable=unused-argument + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return {} + + def get_user_info2( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "user"), + email=UserInfoEmail( + email=oauth_tokens.get("email", "email@test.com"), + is_verified=False, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens2 + ) + provider.get_user_info = get_user_info2 + return provider + + if "custom2" in eval_str: + + def exchange_auth_code_for_oauth_tokens3( + redirect_uri_info: Any, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info["redirectURIQueryParams"] + + def get_user_info3( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=f"custom2{oauth_tokens['email']}", + email=UserInfoEmail( + email=oauth_tokens["email"], + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens3 + ) + provider.get_user_info = get_user_info3 + return provider + + if "custom3" in eval_str: + + def exchange_auth_code_for_oauth_tokens4( + redirect_uri_info: Any, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info["redirectURIQueryParams"] + + def get_user_info4( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens["email"], + email=UserInfoEmail( + email=oauth_tokens["email"], + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens4 + ) + provider.get_user_info = get_user_info4 + return provider + + if "custom" in eval_str: + + def exchange_auth_code_for_oauth_tokens5( + redirect_uri_info: Any, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info + + def get_user_info5( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + if oauth_tokens.get("error"): + raise Exception("Credentials error") + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "userId"), + email=( + None + if oauth_tokens.get("email") is None + else UserInfoEmail( + email=oauth_tokens.get("email"), + is_verified=oauth_tokens.get("isVerified", False), + ) + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens5 + ) + provider.get_user_info = get_user_info5 + return provider + + return custom_provider + + raise Exception("Unknown eval string: " + eval_str) class OverrideParams(APIResponse): From c1ec3c95128cb3cca28dbe4ca10892dd6b441021 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 19:34:16 +0530 Subject: [PATCH 096/126] fixes bugs --- .../recipe/accountlinking/recipe_implementation.py | 6 +++--- tests/test-server/test_functions_mapper.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index ef156e398..5d4495c5b 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -118,7 +118,7 @@ async def can_create_primary_user( ) elif ( response["status"] - == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_PRIMARY_USER_ID_ERROR" + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ): return CanCreatePrimaryUserAccountInfoAlreadyAssociatedError( response["primaryUserId"], response["description"] @@ -155,7 +155,7 @@ async def create_primary_user( ) elif ( response["status"] - == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_PRIMARY_USER_ID_ERROR" + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ): return CreatePrimaryUserAccountInfoAlreadyAssociatedError( response["primaryUserId"], response["description"] @@ -194,7 +194,7 @@ async def can_link_accounts( ) elif ( response["status"] - == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_PRIMARY_USER_ID_ERROR" + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" ): return CanLinkAccountsAccountInfoAlreadyAssociatedError( response["primaryUserId"], response["description"] diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 63f32d0f4..3b744410f 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -71,7 +71,7 @@ async def func( ): if a.get("DO_NOT_LINK"): return ShouldNotAutomaticallyLink() - if i.get("email") == "test2@example.com" and l is None: + if i.email == "test2@example.com" and l is None: return ShouldNotAutomaticallyLink() return ShouldAutomaticallyLink(should_require_verification=False) @@ -101,7 +101,7 @@ async def func( ): if a.get("DO_NOT_LINK"): return ShouldNotAutomaticallyLink() - if i.get("email") == "test2@example.com" and l is None: + if i.email == "test2@example.com" and l is None: return ShouldNotAutomaticallyLink() return ShouldAutomaticallyLink(should_require_verification=True) @@ -109,9 +109,9 @@ async def func( 'async(i,e)=>{if("emailpassword"===i.recipeId){if(!((await supertokens.listUsersByAccountInfo("public",{email:i.email})).length>1))return{shouldAutomaticallyLink:!1}}return{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}}' in eval_str ): - if i.get("recipeId") == "emailpassword": + if i.recipe_id == "emailpassword": users = await list_users_by_account_info( - "public", AccountInfo(email=i.get("email")) + "public", AccountInfo(email=i.email) ) if len(users) <= 1: return ShouldNotAutomaticallyLink() From edfef833860291a5dbebd850b2e1c5dc402ebcfc Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 19:39:47 +0530 Subject: [PATCH 097/126] fixes more stuff --- tests/test-server/test_functions_mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 3b744410f..beb883be1 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -81,7 +81,7 @@ async def func( ): if a.get("DO_NOT_LINK"): return ShouldNotAutomaticallyLink() - if l is not None and l.get("id") == o.getUserId(): + if l is not None and l.id == o.get_user_id(): return ShouldNotAutomaticallyLink() return ShouldAutomaticallyLink(should_require_verification=False) @@ -91,7 +91,7 @@ async def func( ): if a.get("DO_NOT_LINK"): return ShouldNotAutomaticallyLink() - if l is not None and l.get("id") == o.getUserId(): + if l is not None and l.id == o.get_user_id(): return ShouldNotAutomaticallyLink() return ShouldAutomaticallyLink(should_require_verification=True) From d8515942f0dd76abea08a5fb9ba19acf5cf1ec1f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 10 Oct 2024 20:59:19 +0530 Subject: [PATCH 098/126] fixes more bugs --- .pylintrc | 5 +- .../recipe/passwordless/api/implementation.py | 2 +- tests/test-server/app.py | 34 +++++- tests/test-server/test_functions_mapper.py | 100 ++++++++++++++++-- 4 files changed, 129 insertions(+), 12 deletions(-) diff --git a/.pylintrc b/.pylintrc index c271a2fb9..0c8c9f708 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,4 @@ [MASTER] - # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. @@ -20,11 +19,11 @@ fail-on= fail-under=10.0 # Files or directories to be skipped. They should be base names, not paths. -ignore=CVS +ignore=CVS, # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths and can be in Posix or Windows format. -ignore-paths= +ignore-paths=tests/test-server # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index fb6b49db0..3aa530f26 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -652,7 +652,7 @@ async def check_credentials(_: str): if not isinstance(pre_auth_checks_result, OkResponse): if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): - reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] assert isinstance(reason, str) return SignInUpPostNotAllowedResponse(reason) if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): diff --git a/tests/test-server/app.py b/tests/test-server/app.py index bd7489c89..9515a90c7 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -444,19 +444,35 @@ def init_st(config: Dict[str, Any]): elif recipe_config_json.get("contactMethod") == "EMAIL_OR_PHONE": contact_config = passwordless.ContactEmailOrPhoneConfig() + class EmailDeliveryCustom(passwordless.EmailDeliveryInterface[Any]): + async def send_email( + self, template_vars: Any, user_context: Dict[str, Any] + ) -> None: + f = get_func("passwordless.init.emailDelivery.service.sendEmail") + return f(template_vars, user_context) + + class SMSDeliveryCustom(passwordless.SMSDeliveryInterface[Any]): + async def send_sms( + self, template_vars: Any, user_context: Dict[str, Any] + ) -> None: + f = get_func("passwordless.init.smsDelivery.service.sendSms") + return f(template_vars, user_context) + recipe_list.append( passwordless.init( email_delivery=EmailDeliveryConfig( + service=EmailDeliveryCustom(), override=override_builder_with_logging( "Passwordless.emailDelivery.override", config.get("emailDelivery", {}).get("override", None), - ) + ), ), sms_delivery=SMSDeliveryConfig( + service=SMSDeliveryCustom(), override=override_builder_with_logging( "Passwordless.smsDelivery.override", config.get("smsDelivery", {}).get("override", None), - ) + ), ), contact_config=contact_config, flow_type=recipe_config_json.get("flowType"), @@ -657,6 +673,20 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument return jsonify({"error": f"Route not found: {request.method} {request.path}"}), 404 +import traceback +from flask import jsonify + + +@app.errorhandler(Exception) # type: ignore +def handle_exception(e: Exception): + # Print the error and stack trace + print(f"An error occurred: {str(e)}") + traceback.print_exc() + + # Return JSON response with 500 status code + return jsonify({"error": "Internal Server Error", "message": str(e)}), 500 + + add_emailpassword_routes(app) add_multitenancy_routes(app) add_session_routes(app) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index beb883be1..fcdd1644d 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -6,6 +6,8 @@ ShouldAutomaticallyLink, ShouldNotAutomaticallyLink, ) +from supertokens_python.recipe.dashboard.interfaces import APIOptions +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.types import ( RawUserInfoFromProvider, UserInfo, @@ -20,6 +22,9 @@ class Info: def get_func(eval_str: str) -> Callable[..., Any]: + global store # pylint: disable=global-variable-not-assigned + global send_email_inputs # pylint: disable=global-variable-not-assigned + global send_sms_inputs # pylint: disable=global-variable-not-assigned if eval_str.startswith("supertokens.init.supertokens.networkInterceptor"): def func(*args): # type: ignore @@ -28,6 +33,89 @@ def func(*args): # type: ignore return func # type: ignore + elif eval_str.startswith("passwordless.init.emailDelivery.service.sendEmail"): + + def func1( + template_vars: Any, + user_context: Dict[str, Any], # pylint: disable=unused-argument + ) -> None: # pylint: disable=unused-argument + # Add to store + jsonified = { + "codeLifeTime": template_vars.code_life_time, + "email": template_vars.email, + "isFirstFactor": template_vars.is_first_factor, + "preAuthSessionId": template_vars.pre_auth_session_id, + "tenantId": template_vars.tenant_id, + "urlWithLinkCode": template_vars.url_with_link_code, + "userInputCode": template_vars.user_input_code, + } + jsonified = {k: v for k, v in jsonified.items() if v is not None} + if "emailInputs" in store: + store["emailInputs"].append(jsonified) + else: + store["emailInputs"] = [jsonified] + + # Add to send_email_inputs + send_email_inputs.append(jsonified) + + return func1 + + elif eval_str.startswith("passwordless.init.smsDelivery.service.sendSms"): + + def func2( + template_vars: Any, user_context: Dict[str, Any] + ) -> None: # pylint: disable=unused-argument + jsonified = { + "codeLifeTime": template_vars.code_life_time, + "phoneNumber": template_vars.phone_number, + "isFirstFactor": template_vars.is_first_factor, + "preAuthSessionId": template_vars.pre_auth_session_id, + "tenantId": template_vars.tenant_id, + "urlWithLinkCode": template_vars.url_with_link_code, + "userInputCode": template_vars.user_input_code, + } + jsonified = {k: v for k, v in jsonified.items() if v is not None} + send_sms_inputs.append(jsonified) + + return func2 + + elif eval_str.startswith("passwordless.init.override.apis"): + + def func3(oI: Any) -> Dict[str, Any]: + og = oI.consume_code_post + + async def consume_code_post( + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], + ) -> Any: + o = await api_options.request.json() + assert o is not None + if o.get("user_context", {}).get("DO_LINK") is not None: + user_context["DO_LINK"] = o["user_context"]["DO_LINK"] + return await og( + pre_auth_session_id, + user_input_code, + device_id, + link_code, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, + ) + + oI.consume_code_post = consume_code_post + return oI + + return func3 + elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): async def func( @@ -299,7 +387,7 @@ def __init__( send_email_callback_called: Optional[bool] = None, send_email_to_user_email: Optional[str] = None, send_email_inputs: Optional[List[str]] = None, - send_sms_inputs: Optional[List[str]] = None, + send_sms_inputs: List[str] = [], # pylint: disable=dangerous-default-value send_email_to_recipe_user_id: Optional[str] = None, user_in_callback: Optional[User] = None, email: Optional[str] = None, @@ -308,7 +396,7 @@ def __init__( user_id_in_callback: Optional[str] = None, recipe_user_id_in_callback: Optional[str] = None, core_call_count: int = 0, - store: Optional[Any] = None, + store: Dict[str, Any] = {}, # pylint: disable=dangerous-default-value ): self.send_email_to_user_id = send_email_to_user_id self.token = token @@ -409,7 +497,7 @@ def reset_override_params(): new_account_info_in_callback = None user_id_in_callback = None recipe_user_id_in_callback = None - store = None + store = {} Info.core_call_count = 0 @@ -419,8 +507,8 @@ def reset_override_params(): email_post_password_reset: Optional[str] = None send_email_callback_called: bool = False send_email_to_user_email: Optional[str] = None -send_email_inputs: List[str] = [] -send_sms_inputs: List[str] = [] +send_email_inputs: List[Any] = [] +send_sms_inputs: List[Any] = [] send_email_to_recipe_user_id: Optional[str] = None user_in_callback: Optional[User] = None email: Optional[str] = None @@ -428,4 +516,4 @@ def reset_override_params(): new_account_info_in_callback: Optional[RecipeLevelUser] = None user_id_in_callback: Optional[str] = None recipe_user_id_in_callback: Optional[RecipeUserId] = None -store: Optional[str] = None +store: Dict[str, Any] = {} From 579ac9a12647383cbcb6eadba0286a2d5a02537c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 11 Oct 2024 13:42:51 +0530 Subject: [PATCH 099/126] fixes more issues --- tests/test-server/accountlinking.py | 409 ++++++++++++------------- tests/test-server/emailpassword.py | 18 +- tests/test-server/emailverification.py | 66 ++++ 3 files changed, 271 insertions(+), 222 deletions(-) diff --git a/tests/test-server/accountlinking.py b/tests/test-server/accountlinking.py index 8a17e40b0..92fb0b330 100644 --- a/tests/test-server/accountlinking.py +++ b/tests/test-server/accountlinking.py @@ -29,256 +29,235 @@ from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.types import User from utils import serialize_user # pylint: disable=import-error +from session import convert_session_to_container def add_accountlinking_routes(app: Flask): @app.route("/test/accountlinking/createprimaryuser", methods=["POST"]) # type: ignore def create_primary_user_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = create_primary_user( - recipe_user_id, request.json.get("userContext") + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = create_primary_user(recipe_user_id, request.json.get("userContext")) + if isinstance(response, CreatePrimaryUserOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, + } + ) + elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } ) - if isinstance(response, CreatePrimaryUserOkResult): - return jsonify( - { - "status": "OK", - **serialize_user( - response.user, request.headers.get("fdi-version", "") - ), - "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, - } - ) - elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): - return jsonify( - { - "description": response.description, - "primaryUserId": response.primary_user_id, - "status": response.status, - } - ) - elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): - return jsonify( - { - "description": response.description, - "primaryUserId": response.primary_user_id, - "status": response.status, - } - ) - else: - return jsonify( - { - "description": response.description, - "primaryUserId": response.primary_user_id, - "status": response.status, - } - ) - except Exception as e: - return jsonify({"error": str(e)}), 500 @app.route("/test/accountlinking/linkaccounts", methods=["POST"]) # type: ignore def link_accounts_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = link_accounts( - recipe_user_id, - request.json["primaryUserId"], - request.json.get("userContext"), + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = link_accounts( + recipe_user_id, + request.json["primaryUserId"], + request.json.get("userContext"), + ) + if isinstance(response, LinkAccountsOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + "accountsAlreadyLinked": response.accounts_already_linked, + } ) - if isinstance(response, LinkAccountsOkResult): - return jsonify( - { - "status": "OK", - **serialize_user( - response.user, request.headers.get("fdi-version", "") - ), - "accountsAlreadyLinked": response.accounts_already_linked, - } - ) - elif isinstance(response, LinkAccountsRecipeUserIdAlreadyLinkedError): - return jsonify( - { - "description": response.description, - "primaryUserId": response.primary_user_id, - "status": response.status, - **serialize_user( - response.user, request.headers.get("fdi-version", "") - ), - } - ) - elif isinstance(response, LinkAccountsAccountInfoAlreadyAssociatedError): - return jsonify( - { - "description": response.description, - "primaryUserId": response.primary_user_id, - "status": response.status, - } - ) - else: - return jsonify( - { - "description": response.description, - "status": response.status, - } - ) - except Exception as e: - return jsonify({"error": str(e)}), 500 - - @app.route("/test/accountlinking/isemailchangeallowed", methods=["POST"]) # type: ignore - def is_email_change_allowed_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = is_email_change_allowed( - recipe_user_id, - request.json["newEmail"], - request.json["isVerified"], - request.json["session"], - request.json.get("userContext"), + elif isinstance(response, LinkAccountsRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + } ) - return jsonify(response) - except Exception as e: - return jsonify({"error": str(e)}), 500 - - @app.route("/test/accountlinking/unlinkaccount", methods=["POST"]) # type: ignore - def unlink_account_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = unlink_account( - recipe_user_id, - request.json.get("userContext"), + elif isinstance(response, LinkAccountsAccountInfoAlreadyAssociatedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } ) + else: return jsonify( { + "description": response.description, "status": response.status, - "wasRecipeUserDeleted": response.was_recipe_user_deleted, - "wasLinked": response.was_linked, } ) - except Exception as e: - return jsonify({"error": str(e)}), 500 + + @app.route("/test/accountlinking/isemailchangeallowed", methods=["POST"]) # type: ignore + def is_email_change_allowed_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = is_email_change_allowed( + recipe_user_id, + request.json["newEmail"], + request.json["isVerified"], + session, + request.json.get("userContext"), + ) + return jsonify(response) + + @app.route("/test/accountlinking/unlinkaccount", methods=["POST"]) # type: ignore + def unlink_account_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = unlink_account( + recipe_user_id, + request.json.get("userContext"), + ) + return jsonify( + { + "status": response.status, + "wasRecipeUserDeleted": response.was_recipe_user_deleted, + "wasLinked": response.was_linked, + } + ) @app.route("/test/accountlinking/createprimaryuseridorlinkaccounts", methods=["POST"]) # type: ignore def create_primary_user_id_or_link_accounts_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = create_primary_user_id_or_link_accounts( - request.json["tenantId"], - recipe_user_id, - request.json.get("session", None), - request.json.get("userContext", None), - ) - return jsonify(response.to_json()) - except Exception as e: - return jsonify({"error": str(e)}), 500 + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = create_primary_user_id_or_link_accounts( + request.json["tenantId"], + recipe_user_id, + session, + request.json.get("userContext", None), + ) + return jsonify(response.to_json()) @app.route("/test/accountlinking/getprimaryuserthatcanbelinkedtorecipeuserid", methods=["POST"]) # type: ignore def get_primary_user_that_can_be_linked_to_recipe_user_id_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = get_primary_user_that_can_be_linked_to_recipe_user_id( - request.json["tenantId"], - recipe_user_id, - request.json.get("userContext", None), - ) - return jsonify(response.to_json() if response else None) - except Exception as e: - return jsonify({"error": str(e)}), 500 + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = get_primary_user_that_can_be_linked_to_recipe_user_id( + request.json["tenantId"], + recipe_user_id, + request.json.get("userContext", None), + ) + return jsonify(response.to_json() if response else None) @app.route("/test/accountlinking/issignupallowed", methods=["POST"]) # type: ignore def is_signup_allowed_api(): # type: ignore - try: - assert request.json is not None - response = is_sign_up_allowed( - request.json["tenantId"], - AccountInfoWithRecipeId( - recipe_id=request.json["newUser"]["recipeId"], - email=request.json["newUser"]["email"], - phone_number=request.json["newUser"]["phoneNumber"], - third_party=ThirdPartyInfo( - third_party_user_id=request.json["newUser"]["thirdParty"]["id"], - third_party_id=request.json["newUser"]["thirdParty"][ - "thirdPartyId" - ], - ), + assert request.json is not None + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = is_sign_up_allowed( + request.json["tenantId"], + AccountInfoWithRecipeId( + recipe_id=request.json["newUser"]["recipeId"], + email=request.json["newUser"]["email"], + phone_number=request.json["newUser"]["phoneNumber"], + third_party=ThirdPartyInfo( + third_party_user_id=request.json["newUser"]["thirdParty"]["id"], + third_party_id=request.json["newUser"]["thirdParty"][ + "thirdPartyId" + ], ), - request.json["isVerified"], - request.json.get("session", None), - request.json.get("userContext", None), - ) - return jsonify(response) - except Exception as e: - return jsonify({"error": str(e)}), 500 + ), + request.json["isVerified"], + session, + request.json.get("userContext", None), + ) + return jsonify(response) @app.route("/test/accountlinking/issigninallowed", methods=["POST"]) # type: ignore def is_signin_allowed_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = is_sign_in_allowed( - request.json["tenantId"], - recipe_user_id, - request.json.get("session", None), - request.json.get("userContext", None), - ) - return jsonify(response) - except Exception as e: - return jsonify({"error": str(e)}), 500 + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = is_sign_in_allowed( + request.json["tenantId"], + recipe_user_id, + session, + request.json.get("userContext", None), + ) + return jsonify(response) @app.route("/test/accountlinking/verifyemailforrecipeuseriflinkedaccountsareverified", methods=["POST"]) # type: ignore def verify_email_for_recipe_user_if_linked_accounts_are_verified_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - user = User.from_json(request.json["user"]) - async_to_sync_wrapper.sync( - AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( - user=user, - recipe_user_id=recipe_user_id, - user_context=request.json.get("userContext"), - ) + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + user = User.from_json(request.json["user"]) + async_to_sync_wrapper.sync( + AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=request.json.get("userContext"), ) - return jsonify({}) - except Exception as e: - return jsonify({"error": str(e)}), 500 + ) + return jsonify({}) @app.route("/test/accountlinking/cancreateprimaryuser", methods=["POST"]) # type: ignore def can_create_primary_user_api(): # type: ignore - try: - assert request.json is not None - recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) - response = can_create_primary_user( - recipe_user_id, request.json.get("userContext") + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = can_create_primary_user( + recipe_user_id, request.json.get("userContext") + ) + if isinstance(response, CanCreatePrimaryUserOkResult): + return jsonify( + { + "status": response.status, + "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, + } + ) + elif isinstance(response, CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "status": response.status, + "primaryUserId": response.primary_user_id, + } ) - if isinstance(response, CanCreatePrimaryUserOkResult): - return jsonify( - { - "status": response.status, - "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, - } - ) - elif isinstance( - response, CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError - ): - return jsonify( - { - "description": response.description, - "primaryUserId": response.primary_user_id, - "status": response.status, - } - ) - else: - return jsonify( - { - "description": response.description, - "status": response.status, - "primaryUserId": response.primary_user_id, - } - ) - except Exception as e: - return jsonify({"error": str(e)}), 500 diff --git a/tests/test-server/emailpassword.py b/tests/test-server/emailpassword.py index 9691fa21c..553751853 100644 --- a/tests/test-server/emailpassword.py +++ b/tests/test-server/emailpassword.py @@ -10,6 +10,7 @@ ) import supertokens_python.recipe.emailpassword.syncio as emailpassword from session import convert_session_to_container # pylint: disable=import-error +from supertokens_python.types import RecipeUserId from utils import ( # pylint: disable=import-error serialize_user, serialize_recipe_user_id, @@ -26,9 +27,7 @@ def emailpassword_signup(): # type: ignore email = data["email"] password = data["password"] user_context = data.get("userContext") - session = ( - convert_session_to_container(data["session"]) if "session" in data else None - ) + session = convert_session_to_container(data) if "session" in data else None response = emailpassword.sign_up( tenant_id, email, password, session, user_context @@ -66,8 +65,11 @@ def emailpassword_signin(): # type: ignore email = data["email"] password = data["password"] user_context = data.get("userContext") + session = convert_session_to_container(data) if "session" in data else None - response = emailpassword.sign_in(tenant_id, email, password, user_context) + response = emailpassword.sign_in( + tenant_id, email, password, session, user_context + ) if isinstance(response, SignInOkResult): return jsonify( @@ -116,7 +118,7 @@ def emailpassword_update_email_or_password(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - user_id = data["userId"] + recipe_user_id = RecipeUserId(data["recipeUserId"]) email = data.get("email") password = data.get("password") apply_password_policy = data.get("applyPasswordPolicy") @@ -124,7 +126,7 @@ def emailpassword_update_email_or_password(): # type: ignore user_context = data.get("userContext") response = emailpassword.update_email_or_password( - user_id, + recipe_user_id, email, password, apply_password_policy, @@ -139,7 +141,9 @@ def emailpassword_update_email_or_password(): # type: ignore elif isinstance(response, EmailAlreadyExistsError): return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) elif isinstance(response, UpdateEmailOrPasswordEmailChangeNotAllowedError): - return jsonify({"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR"}) + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": response.reason} + ) else: return jsonify( { diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py index 9707ec247..791d394f3 100644 --- a/tests/test-server/emailverification.py +++ b/tests/test-server/emailverification.py @@ -1,5 +1,7 @@ from flask import Flask, request, jsonify +from supertokens_python import async_to_sync_wrapper +from supertokens_python.framework.flask.flask_request import FlaskRequest from supertokens_python.recipe.emailverification.interfaces import ( CreateEmailVerificationTokenOkResult, VerifyEmailUsingTokenOkResult, @@ -10,6 +12,22 @@ def add_emailverification_routes(app: Flask): + @app.route("/test/emailverification/isemailverified", methods=["POST"]) # type: ignore + def is_email_verified_api(): # type: ignore + from supertokens_python import convert_to_recipe_user_id + from supertokens_python.recipe.emailverification.syncio import is_email_verified + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) + email = data.get("email") + user_context = data.get("userContext", {}) + + response = is_email_verified(recipe_user_id, email, user_context) + return jsonify(response) + @app.route("/test/emailverification/createemailverificationtoken", methods=["POST"]) # type: ignore def f(): # type: ignore from supertokens_python import convert_to_recipe_user_id @@ -65,3 +83,51 @@ def f2(): # type: ignore ) else: return jsonify({"status": "EMAIL_VERIFICATION_INVALID_TOKEN_ERROR"}) + + @app.route("/test/emailverification/unverifyemail", methods=["POST"]) # type: ignore + def unverify_email(): # type: ignore + from supertokens_python.recipe.emailverification.syncio import unverify_email + from supertokens_python.types import RecipeUserId + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id = RecipeUserId(data["recipeUserId"]) + email = data.get("email") + user_context = data.get("userContext", {}) + + unverify_email(recipe_user_id, email, user_context) + return jsonify({"status": "OK"}) + + @app.route("/test/emailverification/updatesessionifrequiredpostemailverification", methods=["POST"]) # type: ignore + def update_session_if_required_post_email_verification(): # type: ignore + from supertokens_python.recipe.emailverification import EmailVerificationRecipe + from supertokens_python.types import RecipeUserId + from session import convert_session_to_container, convert_session_to_json + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id_whose_email_got_verified = RecipeUserId( + data["recipeUserIdWhoseEmailGotVerified"]["recipeUserId"] + ) + session = ( + convert_session_to_container(data["session"]) if "session" in data else None + ) + + try: + session_resp = async_to_sync_wrapper.sync( + EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + recipe_user_id_whose_email_got_verified=recipe_user_id_whose_email_got_verified, + session=session, + req=FlaskRequest(request), + user_context=data.get("userContext", {}), + ) + ) + return jsonify( + None if session_resp is None else convert_session_to_json(session_resp) + ) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 From c812d735270eb41e8437f2e76a77551048a7ec2b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 11 Oct 2024 15:14:08 +0530 Subject: [PATCH 100/126] fixes stuff --- supertokens_python/process_state.py | 20 ++++++ .../recipe/accountlinking/interfaces.py | 14 ++++ .../recipe/accountlinking/types.py | 6 ++ .../recipe/emailpassword/interfaces.py | 7 ++ .../recipe/emailpassword/types.py | 5 +- .../recipe/session/access_token.py | 2 +- .../recipe/session/interfaces.py | 6 ++ supertokens_python/types.py | 13 ++++ tests/test-server/app.py | 23 ++++++- tests/test-server/override_logging.py | 54 ++++++++++++++- tests/test-server/session.py | 67 ++----------------- 11 files changed, 149 insertions(+), 68 deletions(-) diff --git a/supertokens_python/process_state.py b/supertokens_python/process_state.py index d7decdd2e..7a8c0f9ca 100644 --- a/supertokens_python/process_state.py +++ b/supertokens_python/process_state.py @@ -54,3 +54,23 @@ def get_event_by_last_event_by_name( if event == state: return event return None + + def wait_for_event( + self, state: PROCESS_STATE, time_in_ms: int = 7000 + ) -> Optional[PROCESS_STATE]: + from time import time, sleep + + start_time = time() + + def try_and_get() -> Optional[PROCESS_STATE]: + result = self.get_event_by_last_event_by_name(state) + if result is None: + if (time() - start_time) * 1000 > time_in_ms: + return None + else: + sleep(1) + return try_and_get() + else: + return result + + return try_and_get() diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index 7ee318f95..0322024fa 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -155,6 +155,13 @@ def __init__(self, user: User, was_already_a_primary_user: bool): self.user = user self.was_already_a_primary_user = was_already_a_primary_user + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "user": self.user.to_json(), + "wasAlreadyAPrimaryUser": self.was_already_a_primary_user, + } + class CreatePrimaryUserRecipeUserIdAlreadyLinkedError: def __init__(self, primary_user_id: str, description: Optional[str] = None): @@ -216,6 +223,13 @@ def __init__(self, accounts_already_linked: bool, user: User): self.accounts_already_linked = accounts_already_linked self.user = user + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "accountsAlreadyLinked": self.accounts_already_linked, + "user": self.user.to_json(), + } + class LinkAccountsRecipeUserIdAlreadyLinkedError: def __init__( diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py index 922b559fb..0037c8614 100644 --- a/supertokens_python/recipe/accountlinking/types.py +++ b/supertokens_python/recipe/accountlinking/types.py @@ -43,6 +43,12 @@ def __init__( "emailpassword", "thirdparty", "passwordless" ] = recipe_id + def to_json(self) -> Dict[str, Any]: + return { + **super().to_json(), + "recipeId": self.recipe_id, + } + class RecipeLevelUser(AccountInfoWithRecipeId): def __init__( diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 49e5441ef..77d8c21f5 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -42,6 +42,13 @@ def __init__(self, user: User, recipe_user_id: RecipeUserId): self.user = user self.recipe_user_id = recipe_user_id + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "user": self.user.to_json(), + "recipeUserId": self.recipe_user_id.get_as_string(), + } + class EmailAlreadyExistsError(APIResponse): status: str = "EMAIL_ALREADY_EXISTS_ERROR" diff --git a/supertokens_python/recipe/emailpassword/types.py b/supertokens_python/recipe/emailpassword/types.py index 55f8cbdaa..c9917b109 100644 --- a/supertokens_python/recipe/emailpassword/types.py +++ b/supertokens_python/recipe/emailpassword/types.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations -from typing import Awaitable, Callable, Optional, TypeVar, Union, Any +from typing import Awaitable, Callable, Dict, Optional, TypeVar, Union, Any from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( @@ -33,6 +33,9 @@ def __init__(self, id: str, value: Any): # pylint: disable=redefined-builtin self.id: str = id self.value: Any = value + def to_json(self) -> Dict[str, Any]: + return {"id": self.id, "value": self.value} + class InputFormField: def __init__( diff --git a/supertokens_python/recipe/session/access_token.py b/supertokens_python/recipe/session/access_token.py index ce4eec245..1fa03950d 100644 --- a/supertokens_python/recipe/session/access_token.py +++ b/supertokens_python/recipe/session/access_token.py @@ -101,7 +101,7 @@ def get_info_from_access_token( user_data = payload session_handle = sanitize_string(payload.get("sessionHandle")) - recipe_user_id = sanitize_string(payload.get("recipeUserId", user_id)) + recipe_user_id = sanitize_string(payload.get("rsub", user_id)) refresh_token_hash_1 = sanitize_string(payload.get("refreshTokenHash1")) parent_refresh_token_hash_1 = sanitize_string( payload.get("parentRefreshTokenHash1") diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 4075f8489..ac5ba49e0 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -135,6 +135,12 @@ def __init__( self.invalid_claims = invalid_claims self.access_token_payload_update = access_token_payload_update + def to_json(self) -> Dict[str, Any]: + return { + "invalidClaims": [i.to_json() for i in self.invalid_claims], + "accessTokenPayloadUpdate": self.access_token_payload_update, + } + class GetSessionTokensDangerouslyDict(TypedDict): accessToken: str diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 939868e71..0a016cb6d 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -48,6 +48,19 @@ def __init__( self.phone_number = phone_number self.third_party = third_party + def to_json(self) -> Dict[str, Any]: + json_repo: Dict[str, Any] = {} + if self.email is not None: + json_repo["email"] = self.email + if self.phone_number is not None: + json_repo["phoneNumber"] = self.phone_number + if self.third_party is not None: + json_repo["thirdParty"] = { + "id": self.third_party.id, + "userId": self.third_party.user_id, + } + return json_repo + class LoginMethod(AccountInfo): def __init__( diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 9515a90c7..f05877f13 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,6 +1,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify +from supertokens_python import process_state from supertokens_python.framework import BaseRequest from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig @@ -150,9 +151,10 @@ def builder(oI: T) -> T: def logging_override_func_sync(name: str, c: Any) -> Any: def inner(*args: Any, **kwargs: Any) -> Any: - override_logging.log_override_event( - name, "CALL", {"args": args, "kwargs": kwargs} - ) + if len(args) > 0: + override_logging.log_override_event(name, "CALL", args) + else: + override_logging.log_override_event(name, "CALL", kwargs) try: res = c(*args, **kwargs) override_logging.log_override_event(name, "RES", res) @@ -668,6 +670,21 @@ def verify_session_route(): return jsonify({"status": "OK"}) +@app.route("/test/waitforevent", methods=["GET"]) # type: ignore +def wait_for_event_api(): # type: ignore + event = request.args.get("event") + if not event: + raise ValueError("event query param missing") + + event_enum = process_state.PROCESS_STATE(int(event)) + instance = process_state.ProcessState.get_instance() + event_result = instance.wait_for_event(event_enum) + if event_result is None: + return jsonify(None) + else: + return jsonify("Found") + + @app.errorhandler(404) def not_found(error: Any) -> Any: # pylint: disable=unused-argument return jsonify({"error": f"Route not found: {request.method} {request.path}"}), 404 diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index f57c30570..08d169ee4 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -1,10 +1,29 @@ -from typing import Any, Dict, List, Set, Union +from typing import Any, Callable, Coroutine, Dict, List, Set, Union import time from httpx import Response from supertokens_python.framework.flask.flask_request import FlaskRequest -from supertokens_python.types import RecipeUserId +from supertokens_python.recipe.accountlinking.interfaces import ( + CreatePrimaryUserOkResult, + LinkAccountsOkResult, +) +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.emailpassword.types import FormField +from supertokens_python.recipe.emailpassword.interfaces import ( + APIOptions as EmailPasswordAPIOptions, + SignUpOkResult, + SignUpPostOkResult, +) +from supertokens_python.recipe.session.interfaces import ClaimsValidationResult +from supertokens_python.recipe.session.session_class import Session +from supertokens_python.recipe.thirdparty.interfaces import ( + APIOptions as ThirdPartyAPIOptions, +) +from supertokens_python.recipe.passwordless.interfaces import ( + APIOptions as PasswordlessAPIOptions, +) +from supertokens_python.types import AccountInfo, RecipeUserId, User override_logs: List[Dict[str, Any]] = [] @@ -43,5 +62,36 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return "Response" if isinstance(data, RecipeUserId): return data.get_as_string() + if isinstance(data, AccountInfoWithRecipeId): + return data.to_json() + if isinstance(data, AccountInfo): + return data.to_json() + if isinstance(data, User): + return data.to_json() + if isinstance(data, Coroutine): + return "Coroutine" + if isinstance(data, Callable): + return "Callable" + if isinstance(data, FormField): + return data.to_json() + if isinstance(data, EmailPasswordAPIOptions): + return "EmailPasswordAPIOptions" + if isinstance(data, ThirdPartyAPIOptions): + return "ThirdPartyAPIOptions" + if isinstance(data, PasswordlessAPIOptions): + return "PasswordlessAPIOptions" + if isinstance(data, SignUpOkResult): + return data.to_json() + if isinstance(data, CreatePrimaryUserOkResult): + return data.to_json() + if isinstance(data, LinkAccountsOkResult): + return data.to_json() + if isinstance(data, Session): + from session import convert_session_to_json + return convert_session_to_json(data) + if isinstance(data, SignUpPostOkResult): + return data.to_json() + if isinstance(data, ClaimsValidationResult): + return data.to_json() return data diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 20bc00807..3a33c6e37 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -89,25 +89,7 @@ def assert_claims(): # type: ignore return jsonify( { "status": "OK", - "updatedSession": { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - }, + "updatedSession": convert_session_to_json(session_container), } ) except Exception as e: @@ -134,26 +116,7 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore return jsonify( { "status": "OK", - "updatedSession": { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "recipeUserId": session_container.get_recipe_user_id().get_as_string(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - }, + "updatedSession": convert_session_to_json(session_container), } ) @@ -170,28 +133,7 @@ def fetch_and_set_claim_api(): # type: ignore user_context = data.get("userContext", {}) session.sync_fetch_and_set_claim(claim, user_context) - response = { - "updatedSession": { - "sessionHandle": session.get_handle(), - "userId": session.get_user_id(), - "recipeUserId": session.get_recipe_user_id().get_as_string(), - "tenantId": session.get_tenant_id(), - "userDataInAccessToken": session.get_access_token_payload(), - "accessToken": session.get_access_token(), - "frontToken": session.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - } - } + response = {"updatedSession": convert_session_to_json(session)} return jsonify(response) @@ -214,6 +156,9 @@ def convert_session_to_json(session_container: SessionContainer) -> Dict[str, An "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ "accessAndFrontTokenUpdated" ], + "recipeUserId": { + "recipeUserId": session_container.get_recipe_user_id().get_as_string() + }, } From 74ddd39e98b8c4119240a204449c533b9e1def2b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 11 Oct 2024 18:31:42 +0530 Subject: [PATCH 101/126] fixes more tests --- .../recipe/emailpassword/interfaces.py | 2 +- .../recipe/emailpassword/recipe.py | 6 +- .../emailpassword/recipe_implementation.py | 10 +- tests/test-server/app.py | 10 +- tests/test-server/test_functions_mapper.py | 138 +++++++++++++++++- 5 files changed, 148 insertions(+), 18 deletions(-) diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 77d8c21f5..02510b947 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -257,7 +257,7 @@ def __init__(self, email: str, user: User): self.user = user def to_json(self) -> Dict[str, Any]: - return {"status": self.status, "email": self.email, "user": self.user.to_json()} + return {"status": self.status} class SignInPostOkResult(APIResponse): diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index b54627641..7ec8469e2 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -71,7 +71,6 @@ InputOverrideConfig, InputSignUpFeature, validate_and_normalise_user_input, - EmailPasswordConfig, ) @@ -97,11 +96,8 @@ def __init__( email_delivery, ) - def get_emailpassword_config() -> EmailPasswordConfig: - return self.config - recipe_implementation = RecipeImplementation( - Querier.get_instance(recipe_id), get_emailpassword_config + Querier.get_instance(recipe_id), self.config ) self.recipe_implementation = ( recipe_implementation diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index f056b1e3a..283569a5f 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Union, Callable +from typing import TYPE_CHECKING, Any, Dict, Union from supertokens_python.asyncio import get_user from supertokens_python.normalised_url_path import NormalisedURLPath @@ -52,11 +52,11 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - get_emailpassword_config: Callable[[], EmailPasswordConfig], + ep_config: EmailPasswordConfig, ): super().__init__() self.querier = querier - self.get_emailpassword_config = get_emailpassword_config + self.ep_config = ep_config async def sign_up( self, @@ -283,9 +283,7 @@ async def update_email_or_password( if password is not None: if apply_password_policy is None or apply_password_policy: - form_fields = ( - self.get_emailpassword_config().sign_up_feature.form_fields - ) + form_fields = self.ep_config.sign_up_feature.form_fields password_field = next( field for field in form_fields if field.id == FORM_FIELD_PASSWORD_ID ) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index f05877f13..2962cb39a 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -233,17 +233,21 @@ def init_st(config: Dict[str, Any]): email_delivery=EmailDeliveryConfig( override=override_builder_with_logging( "EmailPassword.emailDelivery.override", - config.get("emailDelivery", {}).get("override", None), + recipe_config_json.get("emailDelivery", {}).get( + "override", None + ), ) ), override=emailpassword.InputOverrideConfig( apis=override_builder_with_logging( "EmailPassword.override.apis", - config.get("override", {}).get("apis", None), + recipe_config_json.get("override", {}).get("apis", None), ), functions=override_builder_with_logging( "EmailPassword.override.functions", - config.get("override", {}).get("functions", None), + recipe_config_json.get("override", {}).get( + "functions", None + ), ), ), ) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index fcdd1644d..613f25996 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -7,13 +7,26 @@ ShouldNotAutomaticallyLink, ) from supertokens_python.recipe.dashboard.interfaces import APIOptions +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + PasswordPolicyViolationError, + PasswordResetPostOkResult, + PasswordResetTokenInvalidError, + SignUpPostNotAllowedResponse, + SignUpPostOkResult, +) +from supertokens_python.recipe.emailpassword.types import ( + EmailDeliveryOverrideInput, + EmailTemplateVars, + FormField, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.types import ( RawUserInfoFromProvider, UserInfo, UserInfoEmail, ) -from supertokens_python.types import AccountInfo, RecipeUserId +from supertokens_python.types import AccountInfo, GeneralErrorResponse, RecipeUserId from supertokens_python.types import APIResponse, User @@ -33,6 +46,49 @@ def func(*args): # type: ignore return func # type: ignore + elif eval_str.startswith("emailpassword.init.emailDelivery.override"): + + def custom_email_deliver( + original_implementation: EmailDeliveryOverrideInput, + ) -> EmailDeliveryOverrideInput: + original_send_email = original_implementation.send_email + + async def send_email( + template_vars: EmailTemplateVars, user_context: Dict[str, Any] + ) -> None: + global send_email_callback_called # pylint: disable=global-variable-not-assigned + global send_email_to_user_id # pylint: disable=global-variable-not-assigned + global send_email_to_user_email # pylint: disable=global-variable-not-assigned + global send_email_to_recipe_user_id # pylint: disable=global-variable-not-assigned + global token # pylint: disable=global-variable-not-assigned + send_email_callback_called = True + + if template_vars.user: + send_email_to_user_id = template_vars.user.id + + if template_vars.user.email: + send_email_to_user_email = template_vars.user.email + + if template_vars.user.recipe_user_id: + send_email_to_recipe_user_id = ( + template_vars.user.recipe_user_id.get_as_string() + ) + + if template_vars.password_reset_link: + token = ( + template_vars.password_reset_link.split("?")[1] + .split("&")[0] + .split("=")[1] + ) + + # Use the original implementation which calls the default service, + # or a service that you may have specified in the email_delivery object. + return await original_send_email(template_vars, user_context) + + original_implementation.send_email = send_email + return original_implementation + + return custom_email_deliver elif eval_str.startswith("passwordless.init.emailDelivery.service.sendEmail"): def func1( @@ -116,6 +172,75 @@ async def consume_code_post( return func3 + elif eval_str.startswith("emailpassword.init.override.apis"): + from supertokens_python.recipe.emailpassword.interfaces import ( + APIInterface as EmailPasswordAPIInterface, + APIOptions as EmailPasswordAPIOptions, + ) + + def ep_override_apis( + original_implementation: EmailPasswordAPIInterface, + ) -> EmailPasswordAPIInterface: + + og_password_reset_post = original_implementation.password_reset_post + og_sign_up_post = original_implementation.sign_up_post + + async def password_reset_post( + form_fields: List[FormField], + token: str, + tenant_id: str, + api_options: EmailPasswordAPIOptions, + user_context: Dict[str, Any], + ) -> Union[ + PasswordResetPostOkResult, + PasswordResetTokenInvalidError, + PasswordPolicyViolationError, + GeneralErrorResponse, + ]: + if "DO_NOT_LINK" in eval_str: + user_context["DO_NOT_LINK"] = True + t = await og_password_reset_post( + form_fields, token, tenant_id, api_options, user_context + ) + if isinstance(t, PasswordResetPostOkResult): + global email_post_password_reset, user_post_password_reset + email_post_password_reset = t.email + user_post_password_reset = t.user + return t + + async def sign_up_post( + form_fields: List[FormField], + tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + api_options: EmailPasswordAPIOptions, + user_context: Dict[str, Any], + ) -> Union[ + SignUpPostOkResult, + EmailAlreadyExistsError, + SignUpPostNotAllowedResponse, + GeneralErrorResponse, + ]: + if "signUpPOST" in eval_str: + n = await api_options.request.json() + assert n is not None + if n.get("user_context", {}).get("DO_LINK") is not None: + user_context["DO_LINK"] = n["user_context"]["DO_LINK"] + return await og_sign_up_post( + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, + ) + + original_implementation.password_reset_post = password_reset_post + original_implementation.sign_up_post = sign_up_post + return original_implementation + + return ep_override_apis + elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): async def func( @@ -417,7 +542,7 @@ def __init__( self.store = store def to_json(self) -> Dict[str, Any]: - return { + respon_json = { "sendEmailToUserId": self.send_email_to_user_id, "token": self.token, "userPostPasswordReset": ( @@ -430,7 +555,11 @@ def to_json(self) -> Dict[str, Any]: "sendEmailToUserEmail": self.send_email_to_user_email, "sendEmailInputs": self.send_email_inputs, "sendSmsInputs": self.send_sms_inputs, - "sendEmailToRecipeUserId": self.send_email_to_recipe_user_id, + "sendEmailToRecipeUserId": ( + {"recipeUserId": self.send_email_to_recipe_user_id} + if self.send_email_to_recipe_user_id is not None + else None + ), "userInCallback": ( self.user_in_callback.to_json() if self.user_in_callback is not None @@ -450,6 +579,9 @@ def to_json(self) -> Dict[str, Any]: }, "store": self.store, } + # Filter out items that are None + respon_json = {k: v for k, v in respon_json.items() if v is not None} + return respon_json def get_override_params() -> OverrideParams: From 45ff58f08a8cd4dfeb5b790e8daa90350c01fe97 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 11 Oct 2024 20:01:51 +0530 Subject: [PATCH 102/126] fixes stuff --- .../recipe/emailpassword/interfaces.py | 6 +++ .../recipe/emailpassword/types.py | 18 ++++++++ .../recipe/emailverification/interfaces.py | 3 ++ .../recipe/emailverification/types.py | 8 +++- .../recipe/thirdparty/api/signinup.py | 7 ++-- supertokens_python/recipe/thirdparty/types.py | 16 +++++++ tests/test-server/app.py | 25 ++++++++--- tests/test-server/emailverification.py | 2 +- tests/test-server/override_logging.py | 42 ++++++++++++++++++- tests/test-server/test_functions_mapper.py | 27 ++++++------ 10 files changed, 129 insertions(+), 25 deletions(-) diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 02510b947..60fad3644 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -80,6 +80,12 @@ def __init__(self, email: str, user_id: str): self.email = email self.user_id = user_id + def to_json(self) -> Dict[str, Any]: + return { + "email": self.email, + "userId": self.user_id, + } + class PasswordResetTokenInvalidError(APIResponse): status: str = "RESET_PASSWORD_INVALID_TOKEN_ERROR" diff --git a/supertokens_python/recipe/emailpassword/types.py b/supertokens_python/recipe/emailpassword/types.py index c9917b109..e560f2e25 100644 --- a/supertokens_python/recipe/emailpassword/types.py +++ b/supertokens_python/recipe/emailpassword/types.py @@ -75,6 +75,17 @@ def __init__( self.recipe_user_id = recipe_user_id self.email = email + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "recipeUserId": ( + self.recipe_user_id.get_as_string() + if self.recipe_user_id is not None + else None + ), + "email": self.email, + } + class PasswordResetEmailTemplateVars: def __init__( @@ -87,6 +98,13 @@ def __init__( self.password_reset_link = password_reset_link self.tenant_id = tenant_id + def to_json(self) -> Dict[str, Any]: + return { + "user": self.user.to_json(), + "passwordResetLink": self.password_reset_link, + "tenantId": self.tenant_id, + } + # Export: EmailTemplateVars = PasswordResetEmailTemplateVars diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index f1501e357..b6d07c889 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -66,6 +66,9 @@ class VerifyEmailUsingTokenOkResult: def __init__(self, user: EmailVerificationUser): self.user = user + def to_json(self) -> Dict[str, Any]: + return {"user": self.user.to_json(), "status": self.status} + class VerifyEmailUsingTokenInvalidTokenError: pass diff --git a/supertokens_python/recipe/emailverification/types.py b/supertokens_python/recipe/emailverification/types.py index 118c7e70b..c8c5da69c 100644 --- a/supertokens_python/recipe/emailverification/types.py +++ b/supertokens_python/recipe/emailverification/types.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import Union +from typing import Any, Dict, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( @@ -28,6 +28,12 @@ def __init__(self, recipe_user_id: RecipeUserId, email: str): self.recipe_user_id = recipe_user_id self.email = email + def to_json(self) -> Dict[str, Any]: + return { + "recipeUserId": self.recipe_user_id.get_as_string(), + "email": self.email, + } + class VerificationEmailTemplateVarsUser: def __init__(self, _id: str, recipe_user_id: RecipeUserId, email: str): diff --git a/supertokens_python/recipe/thirdparty/api/signinup.py b/supertokens_python/recipe/thirdparty/api/signinup.py index 04b8daa19..bcd561a0c 100644 --- a/supertokens_python/recipe/thirdparty/api/signinup.py +++ b/supertokens_python/recipe/thirdparty/api/signinup.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional from supertokens_python.recipe.thirdparty.interfaces import SignInUpPostOkResult from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo @@ -78,8 +78,9 @@ async def handle_sign_in_up_api( provider = provider_response + redirect_uri_info_parsed: Optional[RedirectUriInfo] = None if redirect_uri_info is not None: - redirect_uri_info = RedirectUriInfo( + redirect_uri_info_parsed = RedirectUriInfo( redirect_uri_on_provider_dashboard=redirect_uri_info.get( "redirectURIOnProviderDashboard" ), @@ -102,7 +103,7 @@ async def handle_sign_in_up_api( result = await api_implementation.sign_in_up_post( provider=provider, - redirect_uri_info=redirect_uri_info, + redirect_uri_info=redirect_uri_info_parsed, oauth_tokens=oauth_tokens, tenant_id=tenant_id, api_options=api_options, diff --git a/supertokens_python/recipe/thirdparty/types.py b/supertokens_python/recipe/thirdparty/types.py index 4a2dc7b9a..de7bfbedf 100644 --- a/supertokens_python/recipe/thirdparty/types.py +++ b/supertokens_python/recipe/thirdparty/types.py @@ -43,12 +43,21 @@ def __init__( self.from_id_token_payload = from_id_token_payload self.from_user_info_api = from_user_info_api + def to_json(self) -> Dict[str, Any]: + return { + "fromIdTokenPayload": self.from_id_token_payload, + "fromUserInfoApi": self.from_user_info_api, + } + class UserInfoEmail: def __init__(self, email: str, is_verified: bool): self.id: str = email self.is_verified: bool = is_verified + def to_json(self) -> Dict[str, Any]: + return {"id": self.id, "isVerified": self.is_verified} + class UserInfo: def __init__( @@ -63,6 +72,13 @@ def __init__( raw_user_info_from_provider or RawUserInfoFromProvider({}, {}) ) + def to_json(self) -> Dict[str, Any]: + return { + "thirdPartyUserId": self.third_party_user_id, + "email": self.email.to_json() if self.email is not None else None, + "rawUserInfoFromProvider": self.raw_user_info_from_provider.to_json(), + } + class AccessTokenAPI: def __init__(self, url: str, params: Dict[str, str]): diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 2962cb39a..1533f75fd 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -100,7 +100,15 @@ def origin_func( # pylint: disable=unused-argument, dangerous-default-value def toCamelCase(snake_case: str) -> str: components = snake_case.split("_") - return components[0] + "".join(x.title() for x in components[1:]) + res = components[0] + "".join(x.title() for x in components[1:]) + # Convert 'post', 'get', or 'put' at the end to uppercase + if res.endswith("Post"): + res = res[:-4] + "POST" + if res.endswith("Get"): + res = res[:-3] + "GET" + if res.endswith("Put"): + res = res[:-3] + "PUT" + return res def create_override( @@ -110,11 +118,16 @@ def create_override( originalFunction = getattr(implementation, functionName) async def finalFunction(*args: Any, **kwargs: Any): - override_logging.log_override_event( - name + "." + toCamelCase(functionName), - "CALL", - {"args": args, "kwargs": kwargs}, - ) + if len(args) > 0: + override_logging.log_override_event( + name + "." + toCamelCase(functionName), + "CALL", + args, + ) + else: + override_logging.log_override_event( + name + "." + toCamelCase(functionName), "CALL", kwargs + ) try: if inspect.iscoroutinefunction(originalFunction): res = await originalFunction(*args, **kwargs) diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py index 791d394f3..208c4a9f0 100644 --- a/tests/test-server/emailverification.py +++ b/tests/test-server/emailverification.py @@ -62,7 +62,7 @@ def f2(): # type: ignore tenant_id = data.get("tenantId", "public") token = data["token"] - attempt_account_linking = data.get("attemptAccountLinking", False) + attempt_account_linking = data.get("attemptAccountLinking", True) user_context = data.get("userContext", {}) response = verify_email_using_token( diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index 08d169ee4..2604922f7 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -9,11 +9,25 @@ LinkAccountsOkResult, ) from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId -from supertokens_python.recipe.emailpassword.types import FormField +from supertokens_python.recipe.emailpassword.types import ( + FormField, + PasswordResetEmailTemplateVars, +) from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions as EmailPasswordAPIOptions, + ConsumePasswordResetTokenOkResult, + CreateResetPasswordOkResult, + GeneratePasswordResetTokenPostOkResult, + PasswordResetPostOkResult, SignUpOkResult, SignUpPostOkResult, + UpdateEmailOrPasswordOkResult, +) +from supertokens_python.recipe.emailverification.interfaces import ( + CreateEmailVerificationTokenEmailAlreadyVerifiedError, + CreateEmailVerificationTokenOkResult, + GetEmailForUserIdOkResult, + VerifyEmailUsingTokenOkResult, ) from supertokens_python.recipe.session.interfaces import ClaimsValidationResult from supertokens_python.recipe.session.session_class import Session @@ -23,6 +37,8 @@ from supertokens_python.recipe.passwordless.interfaces import ( APIOptions as PasswordlessAPIOptions, ) +from supertokens_python.recipe.thirdparty.provider import ProviderConfigForClient +from supertokens_python.recipe.thirdparty.types import UserInfo as TPUserInfo from supertokens_python.types import AccountInfo, RecipeUserId, User override_logs: List[Dict[str, Any]] = [] @@ -94,4 +110,28 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return data.to_json() if isinstance(data, ClaimsValidationResult): return data.to_json() + if isinstance(data, ProviderConfigForClient): + return data.to_json() + if isinstance(data, TPUserInfo): + return data.to_json() + if isinstance(data, GeneratePasswordResetTokenPostOkResult): + return data.to_json() + if isinstance(data, CreateEmailVerificationTokenOkResult): + return {"token": data.token, "status": data.status} + if isinstance(data, GetEmailForUserIdOkResult): + return {"email": data.email, "status": "OK"} + if isinstance(data, VerifyEmailUsingTokenOkResult): + return {"status": data.status} + if isinstance(data, CreateResetPasswordOkResult): + return {"token": data.token, "status": "OK"} + if isinstance(data, PasswordResetEmailTemplateVars): + return data.to_json() + if isinstance(data, ConsumePasswordResetTokenOkResult): + return data.to_json() + if isinstance(data, UpdateEmailOrPasswordOkResult): + return {"status": "OK"} + if isinstance(data, CreateEmailVerificationTokenEmailAlreadyVerifiedError): + return {"status": "EMAIL_ALREADY_VERIFIED_ERROR"} + if isinstance(data, PasswordResetPostOkResult): + return {"status": "OK", "user": data.user.to_json(), "email": data.email} return data diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 613f25996..fb282501f 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -21,6 +21,7 @@ FormField, ) from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo from supertokens_python.recipe.thirdparty.types import ( RawUserInfoFromProvider, UserInfo, @@ -347,13 +348,13 @@ async def func( def custom_provider(provider: Any): if "custom-ev" in eval_str: - def exchange_auth_code_for_oauth_tokens1( - redirect_uri_info: Any, # pylint: disable=unused-argument + async def exchange_auth_code_for_oauth_tokens1( + redirect_uri_info: RedirectUriInfo, user_context: Any, # pylint: disable=unused-argument ) -> Any: - return {} + return redirect_uri_info.redirect_uri_query_params - def get_user_info1( + async def get_user_info1( oauth_tokens: Any, user_context: Any, # pylint: disable=unused-argument ): # pylint: disable=unused-argument @@ -377,13 +378,13 @@ def get_user_info1( if "custom-no-ev" in eval_str: - def exchange_auth_code_for_oauth_tokens2( + async def exchange_auth_code_for_oauth_tokens2( redirect_uri_info: Any, # pylint: disable=unused-argument user_context: Any, # pylint: disable=unused-argument ) -> Any: - return {} + return redirect_uri_info - def get_user_info2( + async def get_user_info2( oauth_tokens: Any, user_context: Any ): # pylint: disable=unused-argument return UserInfo( @@ -406,13 +407,13 @@ def get_user_info2( if "custom2" in eval_str: - def exchange_auth_code_for_oauth_tokens3( + async def exchange_auth_code_for_oauth_tokens3( redirect_uri_info: Any, user_context: Any, # pylint: disable=unused-argument ) -> Any: return redirect_uri_info["redirectURIQueryParams"] - def get_user_info3( + async def get_user_info3( oauth_tokens: Any, user_context: Any ): # pylint: disable=unused-argument return UserInfo( @@ -435,13 +436,13 @@ def get_user_info3( if "custom3" in eval_str: - def exchange_auth_code_for_oauth_tokens4( + async def exchange_auth_code_for_oauth_tokens4( redirect_uri_info: Any, user_context: Any, # pylint: disable=unused-argument ) -> Any: return redirect_uri_info["redirectURIQueryParams"] - def get_user_info4( + async def get_user_info4( oauth_tokens: Any, user_context: Any ): # pylint: disable=unused-argument return UserInfo( @@ -464,13 +465,13 @@ def get_user_info4( if "custom" in eval_str: - def exchange_auth_code_for_oauth_tokens5( + async def exchange_auth_code_for_oauth_tokens5( redirect_uri_info: Any, user_context: Any, # pylint: disable=unused-argument ) -> Any: return redirect_uri_info - def get_user_info5( + async def get_user_info5( oauth_tokens: Any, user_context: Any ): # pylint: disable=unused-argument if oauth_tokens.get("error"): From e50c851986074bd092c650b87bc4fd67462f041e Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 11 Oct 2024 20:17:30 +0530 Subject: [PATCH 103/126] fixes stuff --- tests/test-server/test_functions_mapper.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index fb282501f..c4f5d472c 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -379,10 +379,10 @@ async def get_user_info1( if "custom-no-ev" in eval_str: async def exchange_auth_code_for_oauth_tokens2( - redirect_uri_info: Any, # pylint: disable=unused-argument + redirect_uri_info: RedirectUriInfo, user_context: Any, # pylint: disable=unused-argument ) -> Any: - return redirect_uri_info + return redirect_uri_info.redirect_uri_query_params async def get_user_info2( oauth_tokens: Any, user_context: Any @@ -408,10 +408,10 @@ async def get_user_info2( if "custom2" in eval_str: async def exchange_auth_code_for_oauth_tokens3( - redirect_uri_info: Any, + redirect_uri_info: RedirectUriInfo, user_context: Any, # pylint: disable=unused-argument ) -> Any: - return redirect_uri_info["redirectURIQueryParams"] + return redirect_uri_info.redirect_uri_query_params async def get_user_info3( oauth_tokens: Any, user_context: Any @@ -437,10 +437,10 @@ async def get_user_info3( if "custom3" in eval_str: async def exchange_auth_code_for_oauth_tokens4( - redirect_uri_info: Any, + redirect_uri_info: RedirectUriInfo, user_context: Any, # pylint: disable=unused-argument ) -> Any: - return redirect_uri_info["redirectURIQueryParams"] + return redirect_uri_info.redirect_uri_query_params async def get_user_info4( oauth_tokens: Any, user_context: Any @@ -466,10 +466,10 @@ async def get_user_info4( if "custom" in eval_str: async def exchange_auth_code_for_oauth_tokens5( - redirect_uri_info: Any, + redirect_uri_info: RedirectUriInfo, user_context: Any, # pylint: disable=unused-argument ) -> Any: - return redirect_uri_info + return redirect_uri_info.redirect_uri_query_params async def get_user_info5( oauth_tokens: Any, user_context: Any From c1ab49793cff80dc767b8c66b352b0e9506fec9b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 11 Oct 2024 20:48:41 +0530 Subject: [PATCH 104/126] fixes more stuff --- tests/test-server/app.py | 38 ++++++++++++-- tests/test-server/emailverification.py | 29 +++++------ tests/test-server/supertokens.py | 2 +- tests/test-server/test_functions_mapper.py | 59 ++++++++++++++++++++++ 4 files changed, 107 insertions(+), 21 deletions(-) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 1533f75fd..87ce27673 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify from supertokens_python import process_state -from supertokens_python.framework import BaseRequest +from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig from supertokens_python.recipe import ( @@ -53,7 +53,7 @@ thirdparty, emailverification, ) -from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer from supertokens_python.recipe.session.framework.flask import verify_session from supertokens_python.recipe.thirdparty.provider import UserFields, UserInfoMap from supertokens_python.recipe_module import RecipeModule @@ -267,6 +267,14 @@ def init_st(config: Dict[str, Any]): ) elif recipe_id == "session": + + async def custom_unauthorised_callback( + _: BaseRequest, __: str, response: BaseResponse + ) -> BaseResponse: + response.set_status_code(401) + response.set_json_content(content={"type": "UNAUTHORISED"}) + return response + recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( session.init( @@ -302,6 +310,9 @@ def init_st(config: Dict[str, Any]): ), ), ), + error_handlers=InputErrorHandlers( + on_unauthorised=custom_unauthorised_callback + ), ) ) elif recipe_id == "accountlinking": @@ -437,7 +448,28 @@ def init_st(config: Dict[str, Any]): ev_config["override"] = emailverification.InputOverrideConfig( functions=override_functions ) - recipe_list.append(emailverification.init(**ev_config)) + from supertokens_python.recipe.emailverification.interfaces import ( + UnknownUserIdError, + ) + + recipe_list.append( + emailverification.init( + **ev_config, + get_email_for_recipe_user_id=callback_with_log( + "EmailVerification.getEmailForRecipeUserId", + recipe_config_json.get("getEmailForRecipeUserId"), + UnknownUserIdError(), + ), + email_delivery=EmailDeliveryConfig( + override=override_builder_with_logging( + "EmailVerification.emailDelivery.override", + recipe_config_json.get("emailDelivery", {}).get( + "override", None + ), + ) + ), + ) + ) elif recipe_id == "multifactorauth": recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py index 208c4a9f0..220b87a9a 100644 --- a/tests/test-server/emailverification.py +++ b/tests/test-server/emailverification.py @@ -113,21 +113,16 @@ def update_session_if_required_post_email_verification(): # type: ignore recipe_user_id_whose_email_got_verified = RecipeUserId( data["recipeUserIdWhoseEmailGotVerified"]["recipeUserId"] ) - session = ( - convert_session_to_container(data["session"]) if "session" in data else None - ) - - try: - session_resp = async_to_sync_wrapper.sync( - EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( - recipe_user_id_whose_email_got_verified=recipe_user_id_whose_email_got_verified, - session=session, - req=FlaskRequest(request), - user_context=data.get("userContext", {}), - ) - ) - return jsonify( - None if session_resp is None else convert_session_to_json(session_resp) + session = convert_session_to_container(data) if "session" in data else None + + session_resp = async_to_sync_wrapper.sync( + EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + recipe_user_id_whose_email_got_verified=recipe_user_id_whose_email_got_verified, + session=session, + req=FlaskRequest(request), + user_context=data.get("userContext", {}), ) - except Exception as e: - return jsonify({"status": "ERROR", "message": str(e)}), 500 + ) + return jsonify( + None if session_resp is None else convert_session_to_json(session_resp) + ) diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py index 3cc40d508..b7ecc752d 100644 --- a/tests/test-server/supertokens.py +++ b/tests/test-server/supertokens.py @@ -22,7 +22,7 @@ def delete_user_api(): # type: ignore assert request.json is not None delete_user( request.json["userId"], - request.json["removeAllLinkedAccounts"], + request.json.get("removeAllLinkedAccounts", True), request.json.get("userContext"), ) return jsonify({"status": "OK"}) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index c4f5d472c..07be9b243 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -14,12 +14,17 @@ PasswordResetTokenInvalidError, SignUpPostNotAllowedResponse, SignUpPostOkResult, + UnknownUserIdError, ) from supertokens_python.recipe.emailpassword.types import ( EmailDeliveryOverrideInput, EmailTemplateVars, FormField, ) +from supertokens_python.recipe.emailverification.interfaces import ( + EmailDoesNotExistError, + GetEmailForUserIdOkResult, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo from supertokens_python.recipe.thirdparty.types import ( @@ -47,6 +52,39 @@ def func(*args): # type: ignore return func # type: ignore + elif eval_str.startswith("emailverification.init.emailDelivery.override"): + from supertokens_python.recipe.emailverification.types import ( + EmailDeliveryOverrideInput as EVEmailDeliveryOverrideInput, + EmailTemplateVars as EVEmailTemplateVars, + ) + + def custom_email_delivery_override( + original_implementation: EVEmailDeliveryOverrideInput, + ) -> EVEmailDeliveryOverrideInput: + original_send_email = original_implementation.send_email + + async def send_email( + template_vars: EVEmailTemplateVars, user_context: Dict[str, Any] + ) -> None: + global userInCallback # pylint: disable=global-variable-not-assigned + global token # pylint: disable=global-variable-not-assigned + + if template_vars.user: + userInCallback = template_vars.user + + if template_vars.email_verify_link: + token = template_vars.email_verify_link.split("?token=")[1].split( + "&tenantId=" + )[0] + + # Call the original implementation + await original_send_email(template_vars, user_context) + + original_implementation.send_email = send_email + return original_implementation + + return custom_email_delivery_override + elif eval_str.startswith("emailpassword.init.emailDelivery.override"): def custom_email_deliver( @@ -500,6 +538,27 @@ async def get_user_info5( return custom_provider + if eval_str.startswith("emailverification.init.getEmailForRecipeUserId"): + + async def get_email_for_recipe_user_id( + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], + ) -> Union[ + GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError + ]: + if "random@example.com" in eval_str: + return GetEmailForUserIdOkResult(email="random@example.com") + + if ( + hasattr(recipe_user_id, "get_as_string") + and recipe_user_id.get_as_string() == "random" + ): + return GetEmailForUserIdOkResult(email="test@example.com") + + return UnknownUserIdError() + + return get_email_for_recipe_user_id + raise Exception("Unknown eval string: " + eval_str) From 107a612e19153838b7afd42399ecd6dff6eff9ad Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 13:20:20 +0530 Subject: [PATCH 105/126] fixes stuff --- .../recipe/emailverification/types.py | 7 ++ tests/test-server/accountlinking.py | 26 ++-- tests/test-server/app.py | 30 +++-- tests/test-server/emailverification.py | 1 + tests/test-server/session.py | 51 ++++++++ tests/test-server/test_functions_mapper.py | 114 +++++++++++++++--- 6 files changed, 197 insertions(+), 32 deletions(-) diff --git a/supertokens_python/recipe/emailverification/types.py b/supertokens_python/recipe/emailverification/types.py index c8c5da69c..dc95cbb54 100644 --- a/supertokens_python/recipe/emailverification/types.py +++ b/supertokens_python/recipe/emailverification/types.py @@ -41,6 +41,13 @@ def __init__(self, _id: str, recipe_user_id: RecipeUserId, email: str): self.recipe_user_id = recipe_user_id self.email = email + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "email": self.email, + } + class VerificationEmailTemplateVars: def __init__( diff --git a/tests/test-server/accountlinking.py b/tests/test-server/accountlinking.py index 92fb0b330..6b3ec9a09 100644 --- a/tests/test-server/accountlinking.py +++ b/tests/test-server/accountlinking.py @@ -187,13 +187,25 @@ def is_signup_allowed_api(): # type: ignore request.json["tenantId"], AccountInfoWithRecipeId( recipe_id=request.json["newUser"]["recipeId"], - email=request.json["newUser"]["email"], - phone_number=request.json["newUser"]["phoneNumber"], - third_party=ThirdPartyInfo( - third_party_user_id=request.json["newUser"]["thirdParty"]["id"], - third_party_id=request.json["newUser"]["thirdParty"][ - "thirdPartyId" - ], + email=( + request.json["newUser"]["email"] + if "email" in request.json["newUser"] + else None + ), + phone_number=( + request.json["newUser"]["phoneNumber"] + if "phoneNumber" in request.json["newUser"] + else None + ), + third_party=( + ThirdPartyInfo( + third_party_user_id=request.json["newUser"]["thirdParty"]["id"], + third_party_id=request.json["newUser"]["thirdParty"][ + "thirdPartyId" + ], + ) + if "thirdParty" in request.json["newUser"] + else None ), ), request.json["isVerified"], diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 87ce27673..89fbe4652 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -436,17 +436,9 @@ async def custom_unauthorised_callback( elif recipe_id == "emailverification": recipe_config_json = json.loads(recipe_config.get("config", "{}")) - ev_config: Dict[str, Any] = {"mode": "OPTIONAL"} - if "mode" in recipe_config_json: - ev_config["mode"] = recipe_config_json["mode"] - override_functions = override_builder_with_logging( - "EmailVerification.override.functions", - config.get("override", {}).get("functions", None), - ) - - ev_config["override"] = emailverification.InputOverrideConfig( - functions=override_functions + from supertokens_python.recipe.emailverification.utils import ( + OverrideConfig as EmailVerificationOverrideConfig, ) from supertokens_python.recipe.emailverification.interfaces import ( UnknownUserIdError, @@ -454,7 +446,23 @@ async def custom_unauthorised_callback( recipe_list.append( emailverification.init( - **ev_config, + mode=( + recipe_config_json["mode"] + if "mode" in recipe_config_json + else "OPTIONAL" + ), + override=EmailVerificationOverrideConfig( + apis=override_builder_with_logging( + "EmailVerification.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + functions=override_builder_with_logging( + "EmailVerification.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), get_email_for_recipe_user_id=callback_with_log( "EmailVerification.getEmailForRecipeUserId", recipe_config_json.get("getEmailForRecipeUserId"), diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py index 220b87a9a..e47660f18 100644 --- a/tests/test-server/emailverification.py +++ b/tests/test-server/emailverification.py @@ -76,6 +76,7 @@ def f2(): # type: ignore "user": { "email": response.user.email, "recipeUserId": { + # this is intentionally done this way cause the test in the test suite expects this way. "recipeUserId": response.user.recipe_user_id.get_as_string() }, }, diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 3a33c6e37..b29468db8 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -120,6 +120,31 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore } ) + @app.route("/test/session/getsessioninformation", methods=["POST"]) # type: ignore + def get_session_information_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + session_handle = data["sessionHandle"] + user_context = data.get("userContext", {}) + + response = session.get_session_information(session_handle, user_context) + if response is None: + return jsonify(None) + return jsonify( + { + "customClaimsInAccessTokenPayload": response.custom_claims_in_access_token_payload, + "sessionDataInDatabase": response.session_data_in_database, + "expiry": response.expiry, + "sessionHandle": response.session_handle, + "recipeUserId": response.recipe_user_id.get_as_string(), + "tenantId": response.tenant_id, + "timeCreated": response.time_created, + "userId": response.user_id, + } + ) + @app.route("/test/session/sessionobject/fetchandsetclaim", methods=["POST"]) # type: ignore def fetch_and_set_claim_api(): # type: ignore data = request.json @@ -136,6 +161,30 @@ def fetch_and_set_claim_api(): # type: ignore response = {"updatedSession": convert_session_to_json(session)} return jsonify(response) + @app.route("/test/session/sessionobject/getclaimvalue", methods=["POST"]) # type: ignore + def get_claim_value_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.getclaimvalue", "CALL", data) + session = convert_session_to_container(data) + + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + try: + ret_val = session.sync_get_claim_value(claim, user_context) + response = { + "retVal": ret_val, + "updatedSession": convert_session_to_json(session), + } + log_override_event("sessionobject.getclaimvalue", "RES", ret_val) + return jsonify(response) + except Exception as e: + log_override_event("sessionobject.getclaimvalue", "REJ", str(e)) + raise e + def convert_session_to_json(session_container: SessionContainer) -> Dict[str, Any]: return { @@ -157,6 +206,7 @@ def convert_session_to_json(session_container: SessionContainer) -> Dict[str, An "accessAndFrontTokenUpdated" ], "recipeUserId": { + # this is intentionally done this way cause the test in the test suite expects this way. "recipeUserId": session_container.get_recipe_user_id().get_as_string() }, } @@ -197,6 +247,7 @@ def convert_session_to_container(data: Any) -> Session: ), ) if "refreshToken" in data["session"] + and data["session"]["refreshToken"] is not None else None ), anti_csrf_token=anti_csrf_token, diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 07be9b243..211275f38 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -14,7 +14,6 @@ PasswordResetTokenInvalidError, SignUpPostNotAllowedResponse, SignUpPostOkResult, - UnknownUserIdError, ) from supertokens_python.recipe.emailpassword.types import ( EmailDeliveryOverrideInput, @@ -25,8 +24,16 @@ EmailDoesNotExistError, GetEmailForUserIdOkResult, ) +from supertokens_python.recipe.emailverification.types import ( + VerificationEmailTemplateVarsUser, +) from supertokens_python.recipe.session import SessionContainer -from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo +from supertokens_python.recipe.thirdparty.interfaces import ( + SignInUpNotAllowed, + SignInUpPostNoEmailGivenByProviderResponse, + SignInUpPostOkResult, +) +from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo from supertokens_python.recipe.thirdparty.types import ( RawUserInfoFromProvider, UserInfo, @@ -66,11 +73,11 @@ def custom_email_delivery_override( async def send_email( template_vars: EVEmailTemplateVars, user_context: Dict[str, Any] ) -> None: - global userInCallback # pylint: disable=global-variable-not-assigned + global user_in_callback # pylint: disable=global-variable-not-assigned global token # pylint: disable=global-variable-not-assigned if template_vars.user: - userInCallback = template_vars.user + user_in_callback = template_vars.user if template_vars.email_verify_link: token = template_vars.email_verify_link.split("?token=")[1].split( @@ -280,6 +287,81 @@ async def sign_up_post( return ep_override_apis + elif eval_str.startswith("emailverification.init.override.functions"): + from supertokens_python.recipe.emailverification.interfaces import ( + RecipeInterface as EmailVerificationRecipeInterface, + ) + + def ev_override_functions( + original_implementation: EmailVerificationRecipeInterface, + ) -> EmailVerificationRecipeInterface: + og_is_email_verified = original_implementation.is_email_verified + + async def is_email_verified( + recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] + ) -> bool: + global email_param + email_param = email + return await og_is_email_verified(recipe_user_id, email, user_context) + + original_implementation.is_email_verified = is_email_verified + return original_implementation + + return ev_override_functions + + elif eval_str.startswith("thirdparty.init.override.apis"): + from supertokens_python.recipe.thirdparty.interfaces import ( + APIInterface as ThirdPartyAPIInterface, + APIOptions as ThirdPartyAPIOptions, + ) + + def tp_override_apis( + original_implementation: ThirdPartyAPIInterface, + ) -> ThirdPartyAPIInterface: + async def sign_in_up_post( + provider: Provider, + redirect_uri_info: Optional[RedirectUriInfo], + oauth_tokens: Optional[Dict[str, Any]], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + api_options: ThirdPartyAPIOptions, + user_context: Dict[str, Any], + ) -> Union[ + SignInUpPostOkResult, + SignInUpPostNoEmailGivenByProviderResponse, + SignInUpNotAllowed, + GeneralErrorResponse, + ]: + json_body = await api_options.request.json() + if ( + json_body is not None + and json_body.get("userContext", {}).get("DO_LINK") is not None + ): + user_context["DO_LINK"] = json_body["userContext"]["DO_LINK"] + + result = await original_implementation.sign_in_up_post( + provider, + redirect_uri_info, + oauth_tokens, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, + ) + + if isinstance(result, SignInUpPostOkResult): + global user_in_callback + user_in_callback = result.user + + return result + + original_implementation.sign_in_up_post = sign_in_up_post + return original_implementation + + return tp_override_apis + elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): async def func( @@ -539,13 +621,14 @@ async def get_user_info5( return custom_provider if eval_str.startswith("emailverification.init.getEmailForRecipeUserId"): + from supertokens_python.recipe.emailverification.interfaces import ( + UnknownUserIdError as EVUnknownUserId, + ) async def get_email_for_recipe_user_id( recipe_user_id: RecipeUserId, user_context: Dict[str, Any], - ) -> Union[ - GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError - ]: + ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, EVUnknownUserId]: if "random@example.com" in eval_str: return GetEmailForUserIdOkResult(email="random@example.com") @@ -555,7 +638,7 @@ async def get_email_for_recipe_user_id( ): return GetEmailForUserIdOkResult(email="test@example.com") - return UnknownUserIdError() + return EVUnknownUserId() return get_email_for_recipe_user_id @@ -574,7 +657,9 @@ def __init__( send_email_inputs: Optional[List[str]] = None, send_sms_inputs: List[str] = [], # pylint: disable=dangerous-default-value send_email_to_recipe_user_id: Optional[str] = None, - user_in_callback: Optional[User] = None, + user_in_callback: Optional[ + Union[User, VerificationEmailTemplateVarsUser] + ] = None, email: Optional[str] = None, new_account_info_in_callback: Optional[RecipeLevelUser] = None, primary_user_in_callback: Optional[User] = None, @@ -616,6 +701,7 @@ def to_json(self) -> Dict[str, Any]: "sendEmailInputs": self.send_email_inputs, "sendSmsInputs": self.send_sms_inputs, "sendEmailToRecipeUserId": ( + # this is intentionally done this way cause the test in the test suite expects this way. {"recipeUserId": self.send_email_to_recipe_user_id} if self.send_email_to_recipe_user_id is not None else None @@ -656,7 +742,7 @@ def get_override_params() -> OverrideParams: send_sms_inputs=send_sms_inputs, send_email_to_recipe_user_id=send_email_to_recipe_user_id, user_in_callback=user_in_callback, - email=email, + email=email_param, new_account_info_in_callback=new_account_info_in_callback, primary_user_in_callback=( primary_user_in_callback if primary_user_in_callback else None @@ -673,7 +759,7 @@ def get_override_params() -> OverrideParams: def reset_override_params(): - global send_email_to_user_id, token, user_post_password_reset, email_post_password_reset, send_email_callback_called, send_email_to_user_email, send_email_inputs, send_sms_inputs, send_email_to_recipe_user_id, user_in_callback, email, primary_user_in_callback, new_account_info_in_callback, user_id_in_callback, recipe_user_id_in_callback, store + global send_email_to_user_id, token, user_post_password_reset, email_post_password_reset, send_email_callback_called, send_email_to_user_email, send_email_inputs, send_sms_inputs, send_email_to_recipe_user_id, user_in_callback, email_param, primary_user_in_callback, new_account_info_in_callback, user_id_in_callback, recipe_user_id_in_callback, store send_email_to_user_id = None token = None user_post_password_reset = None @@ -684,7 +770,7 @@ def reset_override_params(): send_sms_inputs = [] send_email_to_recipe_user_id = None user_in_callback = None - email = None + email_param = None primary_user_in_callback = None new_account_info_in_callback = None user_id_in_callback = None @@ -702,8 +788,8 @@ def reset_override_params(): send_email_inputs: List[Any] = [] send_sms_inputs: List[Any] = [] send_email_to_recipe_user_id: Optional[str] = None -user_in_callback: Optional[User] = None -email: Optional[str] = None +user_in_callback: Optional[Union[User, VerificationEmailTemplateVarsUser]] = None +email_param: Optional[str] = None primary_user_in_callback: Optional[User] = None new_account_info_in_callback: Optional[RecipeLevelUser] = None user_id_in_callback: Optional[str] = None From ff089d88e0fcdbf5cf76a0b0a7a19bf58ef53068 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 13:46:20 +0530 Subject: [PATCH 106/126] fixes stuff --- supertokens_python/process_state.py | 18 +++++++++--------- supertokens_python/recipe/thirdparty/types.py | 3 +++ supertokens_python/types.py | 14 ++++++++++---- tests/test-server/supertokens.py | 8 +++----- 4 files changed, 25 insertions(+), 18 deletions(-) diff --git a/supertokens_python/process_state.py b/supertokens_python/process_state.py index 7a8c0f9ca..83429697c 100644 --- a/supertokens_python/process_state.py +++ b/supertokens_python/process_state.py @@ -17,15 +17,15 @@ class PROCESS_STATE(Enum): - CALLING_SERVICE_IN_VERIFY = 1 - CALLING_SERVICE_IN_GET_API_VERSION = 2 - CALLING_SERVICE_IN_REQUEST_HELPER = 3 - MULTI_JWKS_VALIDATION = 4 - IS_SIGN_IN_UP_ALLOWED_NO_PRIMARY_USER_EXISTS = 5 - IS_SIGN_UP_ALLOWED_CALLED = 6 - IS_SIGN_IN_ALLOWED_CALLED = 7 - IS_SIGN_IN_UP_ALLOWED_HELPER_CALLED = 8 - ADDING_NO_CACHE_HEADER_IN_FETCH = 9 + CALLING_SERVICE_IN_VERIFY = 0 + CALLING_SERVICE_IN_GET_API_VERSION = 1 + CALLING_SERVICE_IN_REQUEST_HELPER = 2 + MULTI_JWKS_VALIDATION = 3 + IS_SIGN_IN_UP_ALLOWED_NO_PRIMARY_USER_EXISTS = 4 + IS_SIGN_UP_ALLOWED_CALLED = 5 + IS_SIGN_IN_ALLOWED_CALLED = 6 + IS_SIGN_IN_UP_ALLOWED_HELPER_CALLED = 7 + ADDING_NO_CACHE_HEADER_IN_FETCH = 8 class ProcessState: diff --git a/supertokens_python/recipe/thirdparty/types.py b/supertokens_python/recipe/thirdparty/types.py index de7bfbedf..3e8ec7037 100644 --- a/supertokens_python/recipe/thirdparty/types.py +++ b/supertokens_python/recipe/thirdparty/types.py @@ -33,6 +33,9 @@ def __eq__(self, other: object) -> bool: and self.id == other.id ) + def to_json(self) -> Dict[str, Any]: + return {"userId": self.user_id, "id": self.id} + class RawUserInfoFromProvider: def __init__( diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 0a016cb6d..4e9b8ecd6 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -137,7 +137,7 @@ def to_json(self) -> Dict[str, Any]: "tenantIds": self.tenant_ids, "email": self.email, "phoneNumber": self.phone_number, - "thirdParty": self.third_party.__dict__ if self.third_party else None, + "thirdParty": self.third_party.to_json() if self.third_party else None, "timeJoined": self.time_joined, "verified": self.verified, } @@ -150,12 +150,18 @@ def from_json(json: Dict[str, Any]) -> "LoginMethod": recipe_id=json["recipeId"], recipe_user_id=json["recipeUserId"], tenant_ids=json["tenantIds"], - email=json["email"] if "email" in json else None, - phone_number=json["phoneNumber"] if "phoneNumber" in json else None, + email=( + json["email"] if "email" in json and json["email"] is not None else None + ), + phone_number=( + json["phoneNumber"] + if "phoneNumber" in json and json["phoneNumber"] is not None + else None + ), third_party=( ( TPI(json["thirdParty"]["userId"], json["thirdParty"]["id"]) - if "thirdParty" in json + if "thirdParty" in json and json["thirdParty"] is not None else None ) ), diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py index b7ecc752d..449427a24 100644 --- a/tests/test-server/supertokens.py +++ b/tests/test-server/supertokens.py @@ -39,16 +39,14 @@ def list_users_by_account_info_api(): # type: ignore None if "thirdParty" not in request.json["accountInfo"] else ThirdPartyInfo( - third_party_id=request.json["accountInfo"]["thirdParty"][ - "thirdPartyId" - ], + third_party_id=request.json["accountInfo"]["thirdParty"]["id"], third_party_user_id=request.json["accountInfo"]["thirdParty"][ - "id" + "userId" ], ) ), ), - request.json["doUnionOfAccountInfo"], + request.json.get("doUnionOfAccountInfo", False), request.json.get("userContext"), ) From 9e0406a7c888a3a5d959507511cbb181d692468b Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 14:13:24 +0530 Subject: [PATCH 107/126] fixes more stuff --- tests/test-server/multitenancy.py | 5 +++-- tests/test-server/passwordless.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index 4da9df85a..b617722f5 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -10,6 +10,7 @@ ProviderInput, ) from supertokens_python.recipe.thirdparty.provider import UserFields, UserInfoMap +from supertokens_python.types import RecipeUserId def add_multitenancy_routes(app: Flask): @@ -177,11 +178,11 @@ def associate_user_to_tenant(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data["tenantId"] - user_id = data["userId"] + recipe_user_id = RecipeUserId(data["recipeUserId"]) user_context = data.get("userContext") response = multitenancy.associate_user_to_tenant( - tenant_id, user_id, user_context + tenant_id, recipe_user_id, user_context ) if isinstance(response, AssociateUserToTenantOkResult): diff --git a/tests/test-server/passwordless.py b/tests/test-server/passwordless.py index 4edb510f5..339d68028 100644 --- a/tests/test-server/passwordless.py +++ b/tests/test-server/passwordless.py @@ -158,6 +158,13 @@ def update_user_api(): # type: ignore elif isinstance(response, UpdateUserPhoneNumberAlreadyExistsError): return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) elif isinstance(response, EmailChangeNotAllowedError): - return jsonify({"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR"}) + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": response.reason} + ) else: - return jsonify({"status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR"}) + return jsonify( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": response.reason, + } + ) From 36342334c02b54f19b7d17a22e9c01b42032eacf Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 14:37:41 +0530 Subject: [PATCH 108/126] fixes stuff --- .../recipe/accountlinking/interfaces.py | 3 +-- .../accountlinking/recipe_implementation.py | 4 +--- tests/test-server/accountlinking.py | 1 - tests/test-server/app.py | 2 +- tests/test-server/override_logging.py | 3 +++ tests/test-server/passwordless.py | 7 ++++--- tests/test-server/session.py | 18 ++++++++++++++++++ tests/test-server/test_functions_mapper.py | 18 +++++++++++++++++- 8 files changed, 45 insertions(+), 11 deletions(-) diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py index 0322024fa..2058765d3 100644 --- a/supertokens_python/recipe/accountlinking/interfaces.py +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -260,11 +260,10 @@ def __init__( class LinkAccountsInputUserNotPrimaryError: - def __init__(self, description: Optional[str] = None): + def __init__(self): self.status: Literal[ "INPUT_USER_IS_NOT_A_PRIMARY_USER" ] = "INPUT_USER_IS_NOT_A_PRIMARY_USER" - self.description = description class UnlinkAccountOkResult: diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py index 5d4495c5b..c14cbf22c 100644 --- a/supertokens_python/recipe/accountlinking/recipe_implementation.py +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -290,9 +290,7 @@ async def link_accounts( description=response["description"], ) elif response["status"] == "INPUT_USER_IS_NOT_A_PRIMARY_USER": - return LinkAccountsInputUserNotPrimaryError( - description=response["description"], - ) + return LinkAccountsInputUserNotPrimaryError() else: raise Exception(f"Unknown response status: {response['status']}") diff --git a/tests/test-server/accountlinking.py b/tests/test-server/accountlinking.py index 6b3ec9a09..30b73b694 100644 --- a/tests/test-server/accountlinking.py +++ b/tests/test-server/accountlinking.py @@ -114,7 +114,6 @@ def link_accounts_api(): # type: ignore else: return jsonify( { - "description": response.description, "status": response.status, } ) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 89fbe4652..da08cf5d0 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -326,7 +326,7 @@ async def custom_unauthorised_callback( ), on_account_linked=callback_with_log( "AccountLinking.onAccountLinked", - recipe_config_json.get("on_account_linked"), + recipe_config_json.get("onAccountLinked"), ), override=accountlinking.InputOverrideConfig( functions=override_builder_with_logging( diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index 2604922f7..2bb2d5a21 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -4,6 +4,7 @@ from httpx import Response from supertokens_python.framework.flask.flask_request import FlaskRequest +from supertokens_python.recipe.accountlinking import RecipeLevelUser from supertokens_python.recipe.accountlinking.interfaces import ( CreatePrimaryUserOkResult, LinkAccountsOkResult, @@ -134,4 +135,6 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return {"status": "EMAIL_ALREADY_VERIFIED_ERROR"} if isinstance(data, PasswordResetPostOkResult): return {"status": "OK", "user": data.user.to_json(), "email": data.email} + if isinstance(data, RecipeLevelUser): + return data.to_json() return data diff --git a/tests/test-server/passwordless.py b/tests/test-server/passwordless.py index 339d68028..eb88607d2 100644 --- a/tests/test-server/passwordless.py +++ b/tests/test-server/passwordless.py @@ -70,6 +70,7 @@ def create_code_api(): # type: ignore ) return jsonify( { + "status": "OK", "codeId": response.code_id, "preAuthSessionId": response.pre_auth_session_id, "codeLifeTime": response.code_life_time, @@ -89,10 +90,10 @@ def consume_code_api(): # type: ignore session = convert_session_to_container(body) response = consume_code( - device_id=body["deviceId"], - pre_auth_session_id=body["preAuthSessionId"], + device_id=body.get("deviceId"), + pre_auth_session_id=body.get("preAuthSessionId"), user_input_code=body.get("userInputCode"), - link_code=body["linkCode"], + link_code=body.get("linkCode", None), tenant_id=body.get("tenantId", "public"), user_context=body.get("userContext"), session=session, diff --git a/tests/test-server/session.py b/tests/test-server/session.py index b29468db8..ca1e40968 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -53,6 +53,24 @@ def create_new_session_without_request_response(): # type: ignore return jsonify(convert_session_to_json(session_container)) + @app.route("/test/session/getallsessionhandlesforuser", methods=["POST"]) # type: ignore + def get_all_session_handles_for_user_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + user_id = data["userId"] + fetch_sessions_for_all_linked_accounts = data.get( + "fetchSessionsForAllLinkedAccounts", True + ) + tenant_id = data.get("tenantId", "public") + user_context = data.get("userContext", {}) + + response = session.get_all_session_handles_for_user( + user_id, fetch_sessions_for_all_linked_accounts, tenant_id, user_context + ) + return jsonify(response) + @app.route("/test/session/getsessionwithoutrequestresponse", methods=["POST"]) # type: ignore def get_session_without_request_response(): # type: ignore data = request.json diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 211275f38..cc0c80aef 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -59,6 +59,18 @@ def func(*args): # type: ignore return func # type: ignore + elif eval_str.startswith("accountlinking.init.onAccountLinked"): + + async def on_account_linked( + user: User, recipe_level_user: RecipeLevelUser, user_context: Dict[str, Any] + ) -> None: + global primary_user_in_callback + global new_account_info_in_callback + primary_user_in_callback = user + new_account_info_in_callback = recipe_level_user + + return on_account_linked + elif eval_str.startswith("emailverification.init.emailDelivery.override"): from supertokens_python.recipe.emailverification.types import ( EmailDeliveryOverrideInput as EVEmailDeliveryOverrideInput, @@ -712,7 +724,11 @@ def to_json(self) -> Dict[str, Any]: else None ), "email": self.email, - "newAccountInfoInCallback": self.new_account_info_in_callback, + "newAccountInfoInCallback": ( + self.new_account_info_in_callback.to_json() + if self.new_account_info_in_callback is not None + else None + ), "primaryUserInCallback": ( self.primary_user_in_callback.to_json() if self.primary_user_in_callback is not None From ad7c89e1480fc27bdfded688110a5029aa31fa5c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 15:31:19 +0530 Subject: [PATCH 109/126] fixes stuff --- supertokens_python/querier.py | 4 +- .../recipe/session/interfaces.py | 44 ++++++ tests/test-server/app.py | 43 ++++-- tests/test-server/override_logging.py | 13 +- tests/test-server/session.py | 130 +++++++++++++++++- tests/test-server/test_functions_mapper.py | 61 ++++++++ 6 files changed, 274 insertions(+), 21 deletions(-) diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index c330a56ea..69945493d 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -523,9 +523,9 @@ async def __send_request_helper( raise Exception( "SuperTokens core threw an error for a " + method - + " request to path: " + + " request to path: '" + path.get_as_string_dangerous() - + " with status code: " + + "' with status code: " + str(response.status_code) + " and message: " + response.text # type: ignore diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index ac5ba49e0..0ffe24518 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -60,6 +60,15 @@ def __init__( self.tenant_id = tenant_id self.recipe_user_id = recipe_user_id + def to_json(self) -> Dict[str, Any]: + return { + "handle": self.handle, + "userId": self.user_id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "tenantId": self.tenant_id, + "userDataInJWT": self.user_data_in_jwt, + } + class AccessTokenObj: def __init__(self, token: str, expiry: int, created_time: int): @@ -67,12 +76,27 @@ def __init__(self, token: str, expiry: int, created_time: int): self.expiry = expiry self.created_time = created_time + def to_json(self) -> Dict[str, Any]: + return { + "token": self.token, + "expiry": self.expiry, + "createdTime": self.created_time, + } + class RegenerateAccessTokenOkResult: def __init__(self, session: SessionObj, access_token: Union[AccessTokenObj, None]): self.session = session self.access_token = access_token + def to_json(self) -> Dict[str, Any]: + return { + "session": self.session.to_json(), + "accessToken": ( + self.access_token.to_json() if self.access_token is not None else None + ), + } + class SessionInformationResult: def __init__( @@ -97,6 +121,18 @@ def __init__( self.tenant_id = tenant_id self.recipe_user_id = recipe_user_id + def to_json(self) -> Dict[str, Any]: + return { + "sessionHandle": self.session_handle, + "userId": self.user_id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "sessionDataInDatabase": self.session_data_in_database, + "expiry": self.expiry, + "customClaimsInAccessTokenPayload": self.custom_claims_in_access_token_payload, + "timeCreated": self.time_created, + "tenantId": self.tenant_id, + } + class ReqResInfo: def __init__( @@ -137,6 +173,7 @@ def __init__( def to_json(self) -> Dict[str, Any]: return { + "status": "OK", "invalidClaims": [i.to_json() for i in self.invalid_claims], "accessTokenPayloadUpdate": self.access_token_payload_update, } @@ -397,6 +434,13 @@ def __init__(self, token: str, expiry: int, created_time: int): self.expiry = expiry self.created_time = created_time + def to_json(self) -> Dict[str, Any]: + return { + "token": self.token, + "expiry": self.expiry, + "createdTime": self.created_time, + } + class SessionContainer(ABC): # pylint: disable=too-many-public-methods def __init__( diff --git a/tests/test-server/app.py b/tests/test-server/app.py index da08cf5d0..ed406eb1f 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -26,6 +26,7 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import RecipeUserId from test_functions_mapper import ( # pylint: disable=import-error get_func, get_override_params, @@ -696,29 +697,45 @@ def mock_external_api(): return jsonify({"ok": True}) -# @app.route("/create", methods=["POST"]) -# def create_session(): -# recipe_user_id = request.json.get("recipeUserId") +@app.route("/create", methods=["POST"]) # type: ignore +def create_session_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + recipe_user_id = RecipeUserId(data.get("recipeUserId")) -# session = session.create_new_session(request, "public", recipe_user_id) -# return jsonify({"status": "OK"}) + from supertokens_python.recipe.session.syncio import create_new_session + + create_new_session(request, "public", recipe_user_id) + return jsonify({"status": "OK"}) @app.route("/getsession", methods=["POST"]) # type: ignore @verify_session() def get_session(): - session: SessionContainer = request.environ["session"] + from supertokens_python.recipe.session.syncio import get_session + + session = get_session(request) + assert session is not None return jsonify( - {"userId": session.get_user_id(), "recipeUserId": session.get_user_id()} + { + "userId": session.get_user_id(), + "recipeUserId": session.get_recipe_user_id().get_as_string(), + } ) -# @app.route("/refreshsession", methods=["POST"]) -# def refresh_session(): -# session: SessionContainer = session.refresh_session(request) -# return jsonify( -# {"userId": session.get_user_id(), "recipeUserId": session.get_user_id()} -# ) +@app.route("/refreshsession", methods=["POST"]) # type: ignore +def refresh_session_api(): # type: ignore + from supertokens_python.recipe.session.syncio import refresh_session + + session: SessionContainer = refresh_session(request) + return jsonify( + { + "userId": session.get_user_id(), + "recipeUserId": session.get_recipe_user_id().get_as_string(), + } + ) @app.route("/verify", methods=["GET"]) # type: ignore diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index 2bb2d5a21..307ec083f 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -30,7 +30,12 @@ GetEmailForUserIdOkResult, VerifyEmailUsingTokenOkResult, ) -from supertokens_python.recipe.session.interfaces import ClaimsValidationResult +from supertokens_python.recipe.session.claims import PrimitiveClaim +from supertokens_python.recipe.session.interfaces import ( + ClaimsValidationResult, + RegenerateAccessTokenOkResult, + SessionInformationResult, +) from supertokens_python.recipe.session.session_class import Session from supertokens_python.recipe.thirdparty.interfaces import ( APIOptions as ThirdPartyAPIOptions, @@ -137,4 +142,10 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return {"status": "OK", "user": data.user.to_json(), "email": data.email} if isinstance(data, RecipeLevelUser): return data.to_json() + if isinstance(data, RegenerateAccessTokenOkResult): + return data.to_json() + if isinstance(data, PrimitiveClaim): + return "PrimitiveClaim" + if isinstance(data, SessionInformationResult): + return data.to_json() return data diff --git a/tests/test-server/session.py b/tests/test-server/session.py index ca1e40968..cf1dc174e 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -2,7 +2,11 @@ from flask import Flask, request, jsonify from override_logging import log_override_event # pylint: disable=import-error from supertokens_python.recipe.session import SessionContainer -from supertokens_python.recipe.session.interfaces import TokenInfo +from supertokens_python.recipe.session.exceptions import TokenTheftError +from supertokens_python.recipe.session.interfaces import ( + SessionDoesNotExistError, + TokenInfo, +) from supertokens_python.recipe.session.jwt import ( parse_jwt_without_signature_verification, ) @@ -71,6 +75,59 @@ def get_all_session_handles_for_user_api(): # type: ignore ) return jsonify(response) + @app.route("/test/session/revokeallsessionsforuser", methods=["POST"]) # type: ignore + def revoke_all_sessions_for_user_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + user_id = data["userId"] + revoke_sessions_for_linked_accounts = data.get( + "revokeSessionsForLinkedAccounts", True + ) + tenant_id = data.get("tenantId", None) + user_context = data.get("userContext", {}) + + response = session.revoke_all_sessions_for_user( + user_id, revoke_sessions_for_linked_accounts, tenant_id, user_context + ) + return jsonify(response) + + @app.route("/test/session/refreshsessionwithoutrequestresponse", methods=["POST"]) # type: ignore + def refresh_session_without_request_response(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + refresh_token = data["refreshToken"] + disable_anti_csrf = data.get("disableAntiCsrf") + anti_csrf_token = data.get("antiCsrfToken") + user_context = data.get("userContext", {}) + + try: + response = session.refresh_session_without_request_response( + refresh_token, disable_anti_csrf, anti_csrf_token, user_context + ) + return jsonify(convert_session_to_json(response)) + except Exception as e: + if isinstance(e, TokenTheftError): + return ( + jsonify( + { + "type": "TOKEN_THEFT_DETECTED", + "payload": { + "recipeUserId": { + # this is done this way cause the frontend test suite expects the json in this format + "recipeUserId": e.recipe_user_id.get_as_string() + }, + "userId": e.user_id, + }, + } + ), + 500, + ) + return jsonify({"message": str(e)}), 500 + @app.route("/test/session/getsessionwithoutrequestresponse", methods=["POST"]) # type: ignore def get_session_without_request_response(): # type: ignore data = request.json @@ -138,6 +195,50 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore } ) + @app.route("/test/session/mergeintoaccesspayload", methods=["POST"]) # type: ignore + def merge_into_access_payload(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + session_handle = data["sessionHandle"] + access_token_payload_update = data["accessTokenPayloadUpdate"] + user_context = data.get("userContext", {}) + + try: + response = session.merge_into_access_token_payload( + session_handle, access_token_payload_update, user_context + ) + return jsonify(response) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 + + @app.route("/test/session/validateclaimsforsessionhandle", methods=["POST"]) # type: ignore + def validate_claims_for_session_handle(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + session_handle = data["sessionHandle"] + override_global_claim_validators = None + if "overrideGlobalClaimValidators" in data: + from test_functions_mapper import get_func + + override_global_claim_validators = get_func( + data["overrideGlobalClaimValidators"] + ) + user_context = data.get("userContext", {}) + + try: + response = session.validate_claims_for_session_handle( + session_handle, override_global_claim_validators, user_context + ) + if isinstance(response, SessionDoesNotExistError): + return jsonify({"status": "SESSION_DOES_NOT_EXIST"}) + return jsonify(response.to_json()) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 + @app.route("/test/session/getsessioninformation", methods=["POST"]) # type: ignore def get_session_information_api(): # type: ignore data = request.json @@ -164,7 +265,7 @@ def get_session_information_api(): # type: ignore ) @app.route("/test/session/sessionobject/fetchandsetclaim", methods=["POST"]) # type: ignore - def fetch_and_set_claim_api(): # type: ignore + def session_object_fetch_and_set_claim_api(): # type: ignore data = request.json if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) @@ -179,6 +280,23 @@ def fetch_and_set_claim_api(): # type: ignore response = {"updatedSession": convert_session_to_json(session)} return jsonify(response) + @app.route("/test/session/fetchandsetclaim", methods=["POST"]) # type: ignore + def fetch_and_set_claim_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("session.fetchandsetclaim", "CALL", data) + session_handle = data["sessionHandle"] + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + try: + response = session.fetch_and_set_claim(session_handle, claim, user_context) + return jsonify(response) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 + @app.route("/test/session/sessionobject/getclaimvalue", methods=["POST"]) # type: ignore def get_claim_value_api(): # type: ignore data = request.json @@ -214,9 +332,11 @@ def convert_session_to_json(session_container: SessionContainer) -> Dict[str, An "frontToken": session_container.get_all_session_tokens_dangerously()[ "frontToken" ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], + "refreshToken": ( + session_container.refresh_token.to_json() + if session_container.refresh_token is not None + else None + ), "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ "antiCsrfToken" ], diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index cc0c80aef..8f320dd44 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -28,6 +28,7 @@ VerificationEmailTemplateVarsUser, ) from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.claims import PrimitiveClaim from supertokens_python.recipe.thirdparty.interfaces import ( SignInUpNotAllowed, SignInUpPostNoEmailGivenByProviderResponse, @@ -104,6 +105,66 @@ async def send_email( return custom_email_delivery_override + elif eval_str.startswith("session.override.functions"): + from supertokens_python.recipe.session.interfaces import ( + RecipeInterface as SessionRecipeInterface, + ) + + def session_override_functions( + original_implementation: SessionRecipeInterface, + ) -> SessionRecipeInterface: + original_create_new_session = original_implementation.create_new_session + + async def create_new_session( + user_id: str, + recipe_user_id: RecipeUserId, + access_token_payload: Optional[Dict[str, Any]], + session_data_in_database: Optional[Dict[str, Any]], + disable_anti_csrf: Optional[bool], + tenant_id: str, + user_context: Dict[str, Any], + ) -> SessionContainer: + async def fetch_value( + _user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> None: + global user_id_in_callback + global recipe_user_id_in_callback + user_id_in_callback = user_id + recipe_user_id_in_callback = recipe_user_id + return None + + claim = PrimitiveClaim[Any](key="some-key", fetch_value=fetch_value) + + if access_token_payload is None: + access_token_payload = {} + json_update = await claim.build( + user_id, + recipe_user_id, + tenant_id, + access_token_payload, + user_context, + ) + access_token_payload.update(json_update) + + return await original_create_new_session( + user_id, + recipe_user_id, + access_token_payload, + session_data_in_database, + disable_anti_csrf, + tenant_id, + user_context, + ) + + original_implementation.create_new_session = create_new_session + return original_implementation + + return session_override_functions + elif eval_str.startswith("emailpassword.init.emailDelivery.override"): def custom_email_deliver( From e654e9451818c47611fc552d681135e993b2adf8 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 16:45:43 +0530 Subject: [PATCH 110/126] fixes stuff --- tests/test-server/app.py | 14 +++++++++++++- tests/test-server/test_functions_mapper.py | 5 ++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index ed406eb1f..5667bc2d3 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -431,7 +431,19 @@ async def custom_unauthorised_callback( thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature( providers=providers - ) + ), + override=thirdparty.InputOverrideConfig( + functions=override_builder_with_logging( + "ThirdParty.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + apis=override_builder_with_logging( + "ThirdParty.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + ), ) ) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 8f320dd44..0700a248d 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -391,6 +391,9 @@ async def is_email_verified( def tp_override_apis( original_implementation: ThirdPartyAPIInterface, ) -> ThirdPartyAPIInterface: + + og_sign_in_up_post = original_implementation.sign_in_up_post + async def sign_in_up_post( provider: Provider, redirect_uri_info: Optional[RedirectUriInfo], @@ -413,7 +416,7 @@ async def sign_in_up_post( ): user_context["DO_LINK"] = json_body["userContext"]["DO_LINK"] - result = await original_implementation.sign_in_up_post( + result = await og_sign_in_up_post( provider, redirect_uri_info, oauth_tokens, From 2ad8b1b09a95b3103104690f15a9a762a26b08fc Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 18:56:46 +0530 Subject: [PATCH 111/126] fixe stuff --- .../create_or_update_third_party_config.py | 2 +- .../multitenancy/get_third_party_config.py | 6 ++- .../recipe/multitenancy/interfaces.py | 38 ++++++++++++++++--- .../multitenancy/recipe_implementation.py | 26 ++++++------- .../recipe/passwordless/recipe.py | 8 +--- .../thirdparty/providers/active_directory.py | 10 ++--- .../recipe/thirdparty/providers/discord.py | 2 +- tests/test-server/app.py | 5 ++- tests/test-server/multitenancy.py | 15 +++++--- tests/test-server/thirdparty.py | 20 ++++++++++ 10 files changed, 91 insertions(+), 41 deletions(-) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py index 32f403a82..eb6ac3fbd 100644 --- a/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py +++ b/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py @@ -155,7 +155,7 @@ async def handle_create_or_update_third_party_config( provider_config["clients"][0]["clientSecret"] = resp["clientSecret"] third_party_res = await create_or_update_third_party_config( - tenant_id, provider_config, None, user_context + tenant_id, ProviderConfig.from_json(provider_config), None, user_context ) return CreateOrUpdateThirdPartyConfigOkResult(third_party_res.created_new) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py index 375e3f065..ef050343b 100644 --- a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py @@ -61,8 +61,10 @@ def to_json(self) -> Dict[str, Any]: "isExchangeAuthCodeForOAuthTokensOverridden" ] = self.is_exchange_auth_code_for_oauth_tokens_overridden json_response["isGetUserInfoOverridden"] = self.is_get_user_info_overridden - json_response["status"] = "OK" - return json_response + return { + "status": "OK", + "providerConfig": json_response, + } class GetThirdPartyConfigUnknownTenantError(APIResponse): diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 52e66b3f7..af65930f5 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -62,19 +62,45 @@ class TenantConfigCreateOrUpdate: def __init__( self, core_config: Dict[str, Any] = {}, - first_factors: Optional[List[str]] = None, - required_secondary_factors: Optional[List[str]] = None, + first_factors: Optional[List[str]] = [ + "NO_CHANGE" + ], # A default value here means that if the user does not set this, it will not make any change in the core + required_secondary_factors: Optional[List[str]] = [ + "NO_CHANGE" + ], # A default value here means that if the user does not set this, it will not make any change in the core ): self.core_config = core_config - self.first_factors = first_factors - self.required_secondary_factors = required_secondary_factors + self._first_factors = first_factors + self._required_secondary_factors = required_secondary_factors + + def is_first_factors_unchanged(self) -> bool: + return self._first_factors == ["NO_CHANGE"] + + def is_required_secondary_factors_unchanged(self) -> bool: + return self._required_secondary_factors == ["NO_CHANGE"] + + def get_first_factors_for_update(self) -> Optional[List[str]]: + if self._first_factors == ["NO_CHANGE"]: + raise Exception( + "First check if the value of first_factors is not NO_CHANGE" + ) + return self._first_factors + + def get_required_secondary_factors_for_update(self) -> Optional[List[str]]: + if self._required_secondary_factors == ["NO_CHANGE"]: + raise Exception( + "First check if the value of required_secondary_factors is not NO_CHANGE" + ) + return self._required_secondary_factors @staticmethod def from_json(json: Dict[str, Any]) -> TenantConfigCreateOrUpdate: return TenantConfigCreateOrUpdate( core_config=json.get("coreConfig", {}), - first_factors=json.get("firstFactors", []), - required_secondary_factors=json.get("requiredSecondaryFactors", []), + first_factors=json.get("firstFactors", ["NO_CHANGE"]), + required_secondary_factors=json.get( + "requiredSecondaryFactors", ["NO_CHANGE"] + ), ) diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index c60f148f8..90578bacb 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -132,21 +132,21 @@ async def create_or_update_tenant( config: Optional[TenantConfigCreateOrUpdate], user_context: Dict[str, Any], ) -> CreateOrUpdateTenantOkResult: + json_body: Dict[str, Any] = { + "tenantId": tenant_id, + } + if config is not None: + if not config.is_first_factors_unchanged(): + json_body["firstFactors"] = config.get_first_factors_for_update() + if not config.is_required_secondary_factors_unchanged(): + json_body[ + "requiredSecondaryFactors" + ] = config.get_required_secondary_factors_for_update() + json_body["coreConfig"] = config.core_config + response = await self.querier.send_put_request( NormalisedURLPath("/recipe/multitenancy/tenant/v2"), - { - "tenantId": tenant_id, - "firstFactors": ( - config.first_factors - if config and config.first_factors is not None - else None - ), - "requiredSecondaryFactors": ( - config.required_secondary_factors - if config and config.required_secondary_factors is not None - else None - ), - }, + json_body, user_context=user_context, ) return CreateOrUpdateTenantOkResult( diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index af01c88a0..ae0e752e7 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -72,6 +72,7 @@ from .utils import ( ContactConfig, OverrideConfig, + get_enabled_pwless_factors, validate_and_normalise_user_input, ) from ...post_init_callbacks import PostSTInitCallbacks @@ -155,12 +156,7 @@ def __init__( def callback(): mfa_instance = MultiFactorAuthRecipe.get_instance() - all_factors = [ - FactorIds.OTP_EMAIL, - FactorIds.LINK_EMAIL, - FactorIds.OTP_PHONE, - FactorIds.LINK_PHONE, - ] + all_factors = get_enabled_pwless_factors(self.config) if mfa_instance is not None: async def f1(_: TenantConfig): diff --git a/supertokens_python/recipe/thirdparty/providers/active_directory.py b/supertokens_python/recipe/thirdparty/providers/active_directory.py index 2a4476a90..8ebb00bc1 100644 --- a/supertokens_python/recipe/thirdparty/providers/active_directory.py +++ b/supertokens_python/recipe/thirdparty/providers/active_directory.py @@ -30,13 +30,11 @@ async def get_config_for_client_type( config = await super().get_config_for_client_type(client_type, user_context) if ( - config.additional_config is None - or config.additional_config.get("directoryId") is None + config.additional_config is not None + and config.additional_config.get("directoryId") is not None ): - if not config.oidc_discovery_endpoint: - raise Exception( - "Please provide the directoryId in the additionalConfig of the Active Directory provider." - ) + config.oidc_discovery_endpoint = f"https://login.microsoftonline.com/{config.additional_config['directoryId']}/v2.0/.well-known/openid-configuration" + if config.oidc_discovery_endpoint is not None: config.oidc_discovery_endpoint = ( normalise_oidc_endpoint_to_include_well_known( diff --git a/supertokens_python/recipe/thirdparty/providers/discord.py b/supertokens_python/recipe/thirdparty/providers/discord.py index 677f38032..d301cbcb5 100644 --- a/supertokens_python/recipe/thirdparty/providers/discord.py +++ b/supertokens_python/recipe/thirdparty/providers/discord.py @@ -45,7 +45,7 @@ def Discord(input: ProviderInput) -> Provider: # pylint: disable=redefined-buil input.config.name = "Discord" if not input.config.authorization_endpoint: - input.config.authorization_endpoint = "https://discord.com/api/oauth2/authorize" + input.config.authorization_endpoint = "https://discord.com/oauth2/authorize" if not input.config.token_endpoint: input.config.token_endpoint = "https://discord.com/api/oauth2/token" diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 5667bc2d3..559fe2b6c 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -7,6 +7,7 @@ from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig from supertokens_python.recipe import ( accountlinking, + dashboard, multifactorauth, passwordless, totp, @@ -225,7 +226,9 @@ def init_st(config: Dict[str, Any]): st_reset() override_logging.reset_override_logs() - recipe_list: List[Callable[[AppInfo], RecipeModule]] = [] + recipe_list: List[Callable[[AppInfo], RecipeModule]] = [ + dashboard.init(api_key="test") + ] for recipe_config in config.get("recipeList", []): recipe_id = recipe_config.get("recipeId") if recipe_id == "emailpassword": diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index b617722f5..c7ee4c760 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -20,13 +20,18 @@ def create_or_update_tenant(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data["tenantId"] - config = data["config"] user_context = data.get("userContext") - config = TenantConfigCreateOrUpdate( - first_factors=config.get("firstFactors"), - required_secondary_factors=config.get("requiredSecondaryFactors"), - core_config=config.get("coreConfig"), + config = ( + TenantConfigCreateOrUpdate( + first_factors=data["config"].get("firstFactors"), + required_secondary_factors=data["config"].get( + "requiredSecondaryFactors" + ), + core_config=data["config"].get("coreConfig", {}), + ) + if "config" in data + else None ) response = multitenancy.create_or_update_tenant(tenant_id, config, user_context) diff --git a/tests/test-server/thirdparty.py b/tests/test-server/thirdparty.py index dfc0ff009..513fd5a6b 100644 --- a/tests/test-server/thirdparty.py +++ b/tests/test-server/thirdparty.py @@ -68,3 +68,23 @@ def thirdpartymanuallycreateorupdate(): # type: ignore "reason": response.reason, } ) + + @app.route("/test/thirdparty/getprovider", methods=["POST"]) # type: ignore + def get_provider(): # type: ignore + data = request.get_json() + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + tenant_id = data.get("tenantId", "public") + third_party_id = data["thirdPartyId"] + client_type = data.get("clientType", None) + user_context = data.get("userContext", {}) + + from supertokens_python.recipe.thirdparty.syncio import get_provider + + provider = get_provider(tenant_id, third_party_id, client_type, user_context) + + if provider is None: + return jsonify({}) + + return jsonify({"id": provider.id, "config": provider.config.to_json()}) From 937cbb2b9548fea0a568defc3a3334a883edb36c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 19:01:29 +0530 Subject: [PATCH 112/126] adds comments --- supertokens_python/recipe/multitenancy/interfaces.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index af65930f5..ba8a27537 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -64,10 +64,12 @@ def __init__( core_config: Dict[str, Any] = {}, first_factors: Optional[List[str]] = [ "NO_CHANGE" - ], # A default value here means that if the user does not set this, it will not make any change in the core + ], # A default value here means that if the user does not set this, it will not make any change in the core. This is different from None, + # which means that the user wants to unset it in the core. required_secondary_factors: Optional[List[str]] = [ "NO_CHANGE" - ], # A default value here means that if the user does not set this, it will not make any change in the core + ], # A default value here means that if the user does not set this, it will not make any change in the core. This is different from None, + # which means that the user wants to unset it in the core. ): self.core_config = core_config self._first_factors = first_factors From cb9f1e9f02b41964cbbd57a667658a28ce7e7756 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 12 Oct 2024 21:04:54 +0530 Subject: [PATCH 113/126] more fixes --- tests/test-server/override_logging.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index 307ec083f..41be4e322 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -30,6 +30,7 @@ GetEmailForUserIdOkResult, VerifyEmailUsingTokenOkResult, ) +from supertokens_python.recipe.emailverification.recipe import IsVerifiedSCV from supertokens_python.recipe.session.claims import PrimitiveClaim from supertokens_python.recipe.session.interfaces import ( ClaimsValidationResult, @@ -148,4 +149,5 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return "PrimitiveClaim" if isinstance(data, SessionInformationResult): return data.to_json() - return data + if isinstance(data, IsVerifiedSCV): + return "IsVerifiedSCV" From 42420ce1bb885b4343ff41856747e17e8ce9bf29 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 13 Oct 2024 11:53:39 +0530 Subject: [PATCH 114/126] fixes stuff --- tests/test-server/app.py | 4 ++ tests/test-server/test_functions_mapper.py | 74 ++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 559fe2b6c..712796c34 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -506,6 +506,10 @@ async def custom_unauthorised_callback( "functions", None ), ), + apis=override_builder_with_logging( + "MultifactorAuth.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), ), ) ) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 0700a248d..5f76f915e 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -27,6 +27,10 @@ from supertokens_python.recipe.emailverification.types import ( VerificationEmailTemplateVarsUser, ) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.types import MFARequirementList from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.claims import PrimitiveClaim from supertokens_python.recipe.thirdparty.interfaces import ( @@ -72,6 +76,76 @@ async def on_account_linked( return on_account_linked + elif eval_str.startswith("multifactorauth.init.override.apis"): + from supertokens_python.recipe.multifactorauth.interfaces import ( + APIInterface as MFAAPIInterface, + APIOptions as MFAAPIOptions, + ) + + def mfa_override_apis( + original_implementation: MFAAPIInterface, + ) -> MFAAPIInterface: + original_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAAPIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + json_body = await api_options.request.json() + if ( + json_body is not None + and json_body.get("userContext", {}).get("requireFactor") + is not None + ): + user_context["requireFactor"] = json_body["userContext"][ + "requireFactor" + ] + + return await original_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + return mfa_override_apis + + elif eval_str.startswith("multifactorauth.init.override.functions"): + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + ) + + def mfa_override_functions( + original_implementation: MFARecipeInterface, + ) -> MFARecipeInterface: + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Any, + factors_set_up_for_user: Any, + required_secondary_factors_for_user: Any, + required_secondary_factors_for_tenant: Any, + user_context: Dict[str, Any], + ) -> MFARequirementList: + return ( + MFARequirementList("otp-phone") + if user_context.get("requireFactor") + else MFARequirementList() + ) + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + return original_implementation + + return mfa_override_functions + elif eval_str.startswith("emailverification.init.emailDelivery.override"): from supertokens_python.recipe.emailverification.types import ( EmailDeliveryOverrideInput as EVEmailDeliveryOverrideInput, From a0b0aafaec8e1b8629bd0c4f937c0e9e14db6d41 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 13 Oct 2024 12:52:29 +0530 Subject: [PATCH 115/126] fixes stuff --- tests/test-server/app.py | 38 +++-- tests/test-server/multifactorauth.py | 200 +++++++++++++++++++++++++++ tests/test-server/supertokens.py | 16 +-- tests/test-server/usermetadata.py | 40 ++++++ 4 files changed, 273 insertions(+), 21 deletions(-) create mode 100644 tests/test-server/multifactorauth.py create mode 100644 tests/test-server/usermetadata.py diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 712796c34..09cc9d248 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -657,19 +657,24 @@ def inner( res.get("body"), ) - init( - app_info=InputAppInfo( - app_name=config["appInfo"]["appName"], - api_domain=config["appInfo"]["apiDomain"], - website_domain=config["appInfo"]["websiteDomain"], - ), - supertokens_config=SupertokensConfig( - connection_uri=config["supertokens"]["connectionURI"], - network_interceptor=network_interceptor_func, - ), - framework="flask", - recipe_list=recipe_list, - ) + try: + init( + app_info=InputAppInfo( + app_name=config["appInfo"]["appName"], + api_domain=config["appInfo"]["apiDomain"], + website_domain=config["appInfo"]["websiteDomain"], + ), + supertokens_config=SupertokensConfig( + connection_uri=config["supertokens"]["connectionURI"], + network_interceptor=network_interceptor_func, + ), + framework="flask", + recipe_list=recipe_list, + ) + except Exception as e: + st_reset() + default_st_init() + raise e # Routes @@ -808,6 +813,13 @@ def handle_exception(e: Exception): from supertokens import add_supertokens_routes # pylint: disable=import-error add_supertokens_routes(app) +from usermetadata import add_usermetadata_routes + +add_usermetadata_routes(app) + +from multifactorauth import add_multifactorauth_routes + +add_multifactorauth_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/multifactorauth.py b/tests/test-server/multifactorauth.py new file mode 100644 index 000000000..d82048917 --- /dev/null +++ b/tests/test-server/multifactorauth.py @@ -0,0 +1,200 @@ +from typing import List +from flask import Flask, request, jsonify + +from supertokens_python import async_to_sync_wrapper +from supertokens_python.recipe.multifactorauth.types import MFAClaimValue +from supertokens_python.types import User + + +def add_multifactorauth_routes(app: Flask): + @app.route("/test/multifactorauthclaim/fetchvalue", methods=["POST"]) # type: ignore + def fetch_value_api(): # type: ignore + from supertokens_python import convert_to_recipe_user_id + from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, + ) + + assert request.json is not None + response: MFAClaimValue = async_to_sync_wrapper.sync( # type: ignore + MultiFactorAuthClaim.fetch_value( # type: ignore + request.json["_userId"], + convert_to_recipe_user_id(request.json["recipeUserId"]), + request.json["tenantId"], + request.json["currentPayload"], + request.json.get("userContext"), + ) + ) + return jsonify( + { + "c": response.c, # type: ignore + "v": response.v, # type: ignore + } + ) + + @app.route("/test/multifactorauth/getfactorssetupforuser", methods=["POST"]) # type: ignore + def get_factors_setup_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + get_factors_setup_for_user, + ) + + assert request.json is not None + user_id = request.json["userId"] + user_context = request.json.get("userContext") + + response = get_factors_setup_for_user( + user_id=user_id, + user_context=user_context, + ) + return jsonify(response) + + @app.route("/test/assertallowedtosetupfactorelsethowinvalidclaimerror", methods=["POST"]) # type: ignore + def assert_allowed_to_setup_factor_else_throw_invalid_claim_error_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error, + ) + from session import convert_session_to_container + + assert request.json is not None + + session = None + if request.json.get("session"): + session = convert_session_to_container(request.json) + assert session is not None + + assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session=session, + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/getmfarequirementsforauth", methods=["POST"]) # type: ignore + def get_mfa_requirements_for_auth_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + get_mfa_requirements_for_auth, + ) + from session import convert_session_to_container + + assert request.json is not None + + session = None + if request.json.get("session"): + session = convert_session_to_container(request.json) + assert session is not None + + response = get_mfa_requirements_for_auth( + session=session, + user_context=request.json.get("userContext"), + ) + + return jsonify(response) + + @app.route("/test/multifactorauth/markfactorascompleteinsession", methods=["POST"]) # type: ignore + def mark_factor_as_complete_in_session_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + mark_factor_as_complete_in_session, + ) + from session import convert_session_to_container + + assert request.json is not None + + session = None + if request.json.get("session"): + session = convert_session_to_container(request.json) + assert session is not None + + mark_factor_as_complete_in_session( + session=session, + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/getrequiredsecondaryfactorsforuser", methods=["POST"]) # type: ignore + def get_required_secondary_factors_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + get_required_secondary_factors_for_user, + ) + + assert request.json is not None + + response = get_required_secondary_factors_for_user( + user_id=request.json["userId"], + user_context=request.json.get("userContext"), + ) + + return jsonify(response) + + @app.route("/test/multifactorauth/addtorequiredsecondaryfactorsforuser", methods=["POST"]) # type: ignore + def add_to_required_secondary_factors_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + add_to_required_secondary_factors_for_user, + ) + + assert request.json is not None + + add_to_required_secondary_factors_for_user( + user_id=request.json["userId"], + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/removefromrequiredsecondaryfactorsforuser", methods=["POST"]) # type: ignore + def remove_from_required_secondary_factors_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + remove_from_required_secondary_factors_for_user, + ) + + assert request.json is not None + + remove_from_required_secondary_factors_for_user( + user_id=request.json["userId"], + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/recipeimplementation.getmfarequirementsforauth", methods=["POST"]) # type: ignore + def get_mfa_requirements_for_auth_api2(): # type: ignore + assert request.json is not None + from supertokens_python.recipe.multifactorauth.recipe import ( + MultiFactorAuthRecipe, + ) + + async def user() -> User: + assert request.json is not None + user_json = request.json["user"] + assert user_json is not None + return User.from_json(user_json) + + async def factors_set_up_for_user() -> List[str]: + assert request.json is not None + return request.json["factorsSetUpForUser"] + + async def required_secondary_factors_for_user() -> List[str]: + assert request.json is not None + return request.json["requiredSecondaryFactorsForUser"] + + async def required_secondary_factors_for_tenant() -> List[str]: + assert request.json is not None + return request.json["requiredSecondaryFactorsForTenant"] + + response = async_to_sync_wrapper.sync( + MultiFactorAuthRecipe.get_instance_or_throw_error().recipe_implementation.get_mfa_requirements_for_auth( + tenant_id=request.json["tenantId"], + access_token_payload=request.json["accessTokenPayload"], + completed_factors=request.json["completedFactors"], + user=user, + factors_set_up_for_user=factors_set_up_for_user, + required_secondary_factors_for_user=required_secondary_factors_for_user, + required_secondary_factors_for_tenant=required_secondary_factors_for_tenant, + user_context=request.json.get("userContext"), + ) + ) + + return jsonify(response) diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py index 449427a24..a570e149a 100644 --- a/tests/test-server/supertokens.py +++ b/tests/test-server/supertokens.py @@ -56,10 +56,10 @@ def list_users_by_account_info_api(): # type: ignore def get_users_newest_first_api(): # type: ignore assert request.json is not None response = get_users_newest_first( - include_recipe_ids=request.json["includeRecipeIds"], - limit=request.json["limit"], - pagination_token=request.json["paginationToken"], - tenant_id=request.json["tenantId"], + include_recipe_ids=request.json.get("includeRecipeIds"), + limit=request.json.get("limit"), + pagination_token=request.json.get("paginationToken"), + tenant_id=request.json.get("tenantId"), user_context=request.json.get("userContext"), ) return jsonify( @@ -73,10 +73,10 @@ def get_users_newest_first_api(): # type: ignore def get_users_oldest_first_api(): # type: ignore assert request.json is not None response = get_users_oldest_first( - include_recipe_ids=request.json["includeRecipeIds"], - limit=request.json["limit"], - pagination_token=request.json["paginationToken"], - tenant_id=request.json["tenantId"], + include_recipe_ids=request.json.get("includeRecipeIds"), + limit=request.json.get("limit"), + pagination_token=request.json.get("paginationToken"), + tenant_id=request.json.get("tenantId"), user_context=request.json.get("userContext"), ) return jsonify( diff --git a/tests/test-server/usermetadata.py b/tests/test-server/usermetadata.py new file mode 100644 index 000000000..a4a960774 --- /dev/null +++ b/tests/test-server/usermetadata.py @@ -0,0 +1,40 @@ +from flask import Flask, request, jsonify + +from supertokens_python.recipe.usermetadata.syncio import ( + get_user_metadata, + update_user_metadata, + clear_user_metadata, +) + + +def add_usermetadata_routes(app: Flask): + @app.route("/test/usermetadata/getusermetadata", methods=["POST"]) # type: ignore + def get_user_metadata_api(): # type: ignore + assert request.json is not None + user_id = request.json["userId"] + response = get_user_metadata( + user_id=user_id, user_context=request.json.get("userContext") + ) + return jsonify({"metadata": response.metadata}) + + @app.route("/test/usermetadata/updateusermetadata", methods=["POST"]) # type: ignore + def update_user_metadata_api(): # type: ignore + assert request.json is not None + user_id = request.json["userId"] + metadata_update = request.json["metadataUpdate"] + + response = update_user_metadata( + user_id=user_id, + metadata_update=metadata_update, + user_context=request.json.get("userContext"), + ) + return jsonify({"metadata": response.metadata}) + + @app.route("/test/usermetadata/clearusermetadata", methods=["POST"]) # type: ignore + def clear_user_metadata_api(): # type: ignore + assert request.json is not None + user_id = request.json["userId"] + clear_user_metadata( + user_id=user_id, user_context=request.json.get("userContext") + ) + return jsonify({"status": "OK"}) From 21ba18e475edefc6245de1149d7c587c41a6f5dd Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 13 Oct 2024 13:37:34 +0530 Subject: [PATCH 116/126] fixes stuff --- tests/test-server/override_logging.py | 2 ++ tests/test-server/test_functions_mapper.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index 41be4e322..b29cf1b4a 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -151,3 +151,5 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return data.to_json() if isinstance(data, IsVerifiedSCV): return "IsVerifiedSCV" + + return data diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 5f76f915e..df95a83a2 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -346,8 +346,8 @@ async def consume_code_post( ) -> Any: o = await api_options.request.json() assert o is not None - if o.get("user_context", {}).get("DO_LINK") is not None: - user_context["DO_LINK"] = o["user_context"]["DO_LINK"] + if o.get("userContext", {}).get("DO_LINK") is not None: + user_context["DO_LINK"] = o["userContext"]["DO_LINK"] return await og( pre_auth_session_id, user_input_code, @@ -417,8 +417,8 @@ async def sign_up_post( if "signUpPOST" in eval_str: n = await api_options.request.json() assert n is not None - if n.get("user_context", {}).get("DO_LINK") is not None: - user_context["DO_LINK"] = n["user_context"]["DO_LINK"] + if n.get("userContext", {}).get("DO_LINK") is not None: + user_context["DO_LINK"] = n["userContext"]["DO_LINK"] return await og_sign_up_post( form_fields, tenant_id, From f6cf30402523f20ba72158abb897783e3bb43ba9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 13 Oct 2024 14:23:42 +0530 Subject: [PATCH 117/126] fixes a small issue --- supertokens_python/post_init_callbacks.py | 4 ++ tests/test-server/app.py | 2 + tests/test-server/override_logging.py | 13 +++++- tests/test-server/test_functions_mapper.py | 50 ++++++++++++++++++++++ 4 files changed, 68 insertions(+), 1 deletion(-) diff --git a/supertokens_python/post_init_callbacks.py b/supertokens_python/post_init_callbacks.py index 982e1b741..ddbf0afa9 100644 --- a/supertokens_python/post_init_callbacks.py +++ b/supertokens_python/post_init_callbacks.py @@ -29,3 +29,7 @@ def run_post_init_callbacks() -> None: for cb in PostSTInitCallbacks.post_init_callbacks: cb() PostSTInitCallbacks.post_init_callbacks = [] + + @staticmethod + def reset(): + PostSTInitCallbacks.post_init_callbacks = [] diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 09cc9d248..bd08f1674 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -5,6 +5,7 @@ from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig +from supertokens_python.post_init_callbacks import PostSTInitCallbacks from supertokens_python.recipe import ( accountlinking, dashboard, @@ -202,6 +203,7 @@ async def default_func( # pylint: disable=unused-argument def st_reset(): + PostSTInitCallbacks.reset() override_logging.reset_override_logs() reset_override_params() ProcessState.get_instance().reset() diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index b29cf1b4a..03a84328d 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -1,3 +1,4 @@ +import json from typing import Any, Callable, Coroutine, Dict, List, Set, Union import time @@ -151,5 +152,15 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return data.to_json() if isinstance(data, IsVerifiedSCV): return "IsVerifiedSCV" + if is_jsonable(data): + return data - return data + return "Some custom object" + + +def is_jsonable(x: Any) -> bool: + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index df95a83a2..52fc98d91 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -1,6 +1,7 @@ from typing import Callable, List, Union from typing import Dict, Any, Optional from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.accountlinking import ( RecipeLevelUser, ShouldAutomaticallyLink, @@ -35,6 +36,7 @@ from supertokens_python.recipe.session.claims import PrimitiveClaim from supertokens_python.recipe.thirdparty.interfaces import ( SignInUpNotAllowed, + SignInUpOkResult, SignInUpPostNoEmailGivenByProviderResponse, SignInUpPostOkResult, ) @@ -309,6 +311,54 @@ def func1( return func1 + if eval_str.startswith("thirdparty.init.override.functions"): + if "setIsVerifiedInSignInUp" in eval_str: + from supertokens_python.recipe.thirdparty.interfaces import ( + RecipeInterface as ThirdPartyRecipeInterface, + ) + + def custom_override( + original_implementation: ThirdPartyRecipeInterface, + ) -> ThirdPartyRecipeInterface: + og_sign_in_up = original_implementation.sign_in_up + + async def sign_in_up( + third_party_id: str, + third_party_user_id: str, + email: str, + is_verified: bool, + oauth_tokens: Dict[str, Any], + raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + SignInUpOkResult, + SignInUpNotAllowed, + LinkingToSessionUserFailedError, + ]: + user_context[ + "isVerified" + ] = is_verified # this information comes from the third party provider + return await og_sign_in_up( + third_party_id, + third_party_user_id, + email, + is_verified, + oauth_tokens, + raw_user_info_from_provider, + session, + should_try_linking_with_session_user, + tenant_id, + user_context, + ) + + original_implementation.sign_in_up = sign_in_up + return original_implementation + + return custom_override + elif eval_str.startswith("passwordless.init.smsDelivery.service.sendSms"): def func2( From b8fd8d0020af54625fdb6ba970c24fb848945c88 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sun, 13 Oct 2024 21:11:31 +0530 Subject: [PATCH 118/126] makes type simpler --- .../multifactorauth/recipe_implementation.py | 2 +- .../recipe/multifactorauth/types.py | 21 +++---------------- tests/test-server/test_functions_mapper.py | 6 +----- 3 files changed, 5 insertions(+), 24 deletions(-) diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py index fc5cd1c90..476582eab 100644 --- a/supertokens_python/recipe/multifactorauth/recipe_implementation.py +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -147,7 +147,7 @@ async def get_mfa_requirements_for_auth( all_factors.add(factor) for factor in await required_secondary_factors_for_tenant(): all_factors.add(factor) - return MFARequirementList({"oneOf": list(all_factors)}) + return [{"oneOf": list(all_factors)}] async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( self, diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index 779c8ad77..a53f12e30 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -22,24 +22,9 @@ from .interfaces import RecipeInterface, APIInterface -class MFARequirementList(List[Union[Dict[str, List[str]], str]]): - def __init__( - self, - *args: Union[ - str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]] - ], - ): - super().__init__() - for arg in args: - if isinstance(arg, str): - self.append(arg) - else: - if "oneOf" in arg: - self.append({"oneOf": arg["oneOf"]}) - elif "allOfInAnyOrder" in arg: - self.append({"allOfInAnyOrder": arg["allOfInAnyOrder"]}) - else: - raise ValueError("Invalid dictionary format") +MFARequirementList = List[ + Union[str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]]] +] class MFAClaimValue: diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 52fc98d91..ede8502f2 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -135,11 +135,7 @@ async def get_mfa_requirements_for_auth( required_secondary_factors_for_tenant: Any, user_context: Dict[str, Any], ) -> MFARequirementList: - return ( - MFARequirementList("otp-phone") - if user_context.get("requireFactor") - else MFARequirementList() - ) + return ["otp-phone"] if user_context.get("requireFactor") else [] original_implementation.get_mfa_requirements_for_auth = ( get_mfa_requirements_for_auth From 5236c14f807a99ac88692e03ead769416722859f Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 14 Oct 2024 15:57:11 +0530 Subject: [PATCH 119/126] fixes more stuff --- supertokens_python/constants.py | 2 +- .../emailpassword/api/implementation.py | 217 ++++++++++++------ .../recipe/emailpassword/types.py | 5 +- tests/test-server/test_functions_mapper.py | 23 ++ 4 files changed, 169 insertions(+), 78 deletions(-) diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 9225cf37a..62464cd86 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -28,6 +28,6 @@ FDI_KEY_HEADER = "fdi-version" API_VERSION = "/apiversion" API_VERSION_HEADER = "cdi-version" -DASHBOARD_VERSION = "0.7" +DASHBOARD_VERSION = "0.13" ONE_YEAR_IN_MS = 31536000000 RATE_LIMIT_STATUS_CODE = 429 diff --git a/supertokens_python/recipe/emailpassword/api/implementation.py b/supertokens_python/recipe/emailpassword/api/implementation.py index 8d1a483fa..39490b324 100644 --- a/supertokens_python/recipe/emailpassword/api/implementation.py +++ b/supertokens_python/recipe/emailpassword/api/implementation.py @@ -176,11 +176,50 @@ async def generate_and_send_password_reset_token( None, ) - primary_user_associated_with_email = next( - (u for u in users if u.is_primary_user), None + linking_candidate = next((u for u in users if u.is_primary_user), None) + + # first we check if there even exists a primary user that has the input email + log_debug_message( + f"generatePasswordResetTokenPOST: primary linking candidate: {linking_candidate.id if linking_candidate else None}" + ) + log_debug_message( + f"generatePasswordResetTokenPOST: linking candidate count {len(users)}" ) - if primary_user_associated_with_email is None: + # If there is no existing primary user and there is a single option to link + # we see if that user can become primary (and a candidate for linking) + if linking_candidate is None and len(users) > 0: + # If the only user that exists with this email is a non-primary emailpassword user, then we can just let them reset their password, because: + # we are not going to link anything and there is no risk of account takeover. + if ( + email_password_account is not None + and len(users) == 1 + and users[0].login_methods[0].recipe_user_id.get_as_string() + == email_password_account.recipe_user_id.get_as_string() + ): + return await generate_and_send_password_reset_token( + email_password_account.recipe_user_id.get_as_string(), + email_password_account.recipe_user_id, + ) + + oldest_user = min(users, key=lambda u: u.time_joined) + log_debug_message( + f"generatePasswordResetTokenPOST: oldest recipe level-linking candidate: {oldest_user.id} (w/ {'verified' if oldest_user.login_methods[0].verified else 'unverified'} email)" + ) + # Otherwise, we check if the user can become primary. + should_become_primary_user = ( + await AccountLinkingRecipe.get_instance().should_become_primary_user( + oldest_user, tenant_id, None, user_context + ) + ) + + log_debug_message( + f"generatePasswordResetTokenPOST: recipe level-linking candidate {'can' if should_become_primary_user else 'can not'} become primary" + ) + if should_become_primary_user: + linking_candidate = oldest_user + + if linking_candidate is None: if email_password_account is None: log_debug_message( f"Password reset email not sent, unknown user email: {email}" @@ -193,13 +232,13 @@ async def generate_and_send_password_reset_token( email_verified = any( lm.has_same_email_as(email) and lm.verified - for lm in primary_user_associated_with_email.login_methods + for lm in linking_candidate.login_methods ) has_other_email_or_phone = any( (lm.email is not None and not lm.has_same_email_as(email)) or lm.phone_number is not None - for lm in primary_user_associated_with_email.login_methods + for lm in linking_candidate.login_methods ) if not email_verified and has_other_email_or_phone: @@ -207,12 +246,27 @@ async def generate_and_send_password_reset_token( "Reset password link was not created because of account take over risk. Please contact support. (ERR_CODE_001)" ) + if linking_candidate.is_primary_user and email_password_account is not None: + # If a primary user has the input email as verified or has no other emails then it is always allowed to reset their own password: + # - there is no risk of account takeover, because they have verified this email or haven't linked it to anything else (checked above this block) + # - there will be no linking as a result of this action, so we do not need to check for linking (checked here by seeing that the two accounts are already linked) + are_the_two_accounts_linked = any( + lm.recipe_user_id.get_as_string() + == email_password_account.recipe_user_id.get_as_string() + for lm in linking_candidate.login_methods + ) + + if are_the_two_accounts_linked: + return await generate_and_send_password_reset_token( + linking_candidate.id, email_password_account.recipe_user_id + ) + should_do_account_linking_response = await AccountLinkingRecipe.get_instance().config.should_do_automatic_account_linking( AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( email_password_account or AccountInfoWithRecipeId(email=email, recipe_id="emailpassword") ), - primary_user_associated_with_email, + linking_candidate, None, tenant_id, user_context, @@ -240,7 +294,7 @@ async def generate_and_send_password_reset_token( ) if is_sign_up_allowed: return await generate_and_send_password_reset_token( - primary_user_associated_with_email.id, None + linking_candidate.id, None ) else: log_debug_message( @@ -248,32 +302,14 @@ async def generate_and_send_password_reset_token( ) return GeneratePasswordResetTokenPostOkResult() - are_the_two_accounts_linked = any( - lm.recipe_user_id.get_as_string() - == email_password_account.recipe_user_id.get_as_string() - for lm in primary_user_associated_with_email.login_methods - ) - - if are_the_two_accounts_linked: - return await generate_and_send_password_reset_token( - primary_user_associated_with_email.id, - email_password_account.recipe_user_id, - ) - if isinstance(should_do_account_linking_response, ShouldNotAutomaticallyLink): return await generate_and_send_password_reset_token( email_password_account.recipe_user_id.get_as_string(), email_password_account.recipe_user_id, ) - if not should_do_account_linking_response.should_require_verification: - return await generate_and_send_password_reset_token( - primary_user_associated_with_email.id, - email_password_account.recipe_user_id, - ) - return await generate_and_send_password_reset_token( - primary_user_associated_with_email.id, email_password_account.recipe_user_id + linking_candidate.id, email_password_account.recipe_user_id ) async def password_reset_post( @@ -388,71 +424,100 @@ async def do_update_password_and_verify_email_and_try_link_if_not_primary( user_id_for_whom_token_was_generated = token_consumption_response.user_id email_for_whom_token_was_generated = token_consumption_response.email - existing_user = await get_user(token_consumption_response.user_id, user_context) + existing_user = await get_user( + user_id_for_whom_token_was_generated, user_context + ) if existing_user is None: return PasswordResetTokenInvalidError() - if existing_user.is_primary_user: - email_password_user_is_linked_to_existing_user = any( - lm.recipe_user_id.get_as_string() - == user_id_for_whom_token_was_generated - and lm.recipe_id == "emailpassword" - for lm in existing_user.login_methods - ) + token_generated_for_email_password_user = any( + lm.recipe_user_id.get_as_string() == user_id_for_whom_token_was_generated + and lm.recipe_id == "emailpassword" + for lm in existing_user.login_methods + ) - if email_password_user_is_linked_to_existing_user: + if token_generated_for_email_password_user: + if not existing_user.is_primary_user: + # If this is a recipe level emailpassword user, we can always allow them to reset their password. return await do_update_password_and_verify_email_and_try_link_if_not_primary( RecipeUserId(user_id_for_whom_token_was_generated) ) - else: - create_user_response = ( - await api_options.recipe_implementation.create_new_recipe_user( - tenant_id=tenant_id, - email=token_consumption_response.email, - password=new_password, - user_context=user_context, - ) + + # If the user is a primary user resetting the password of an emailpassword user linked to it + # we need to check for account takeover risk (similar to what we do when generating the token) + + # We check if there is any login method in which the input email is verified. + # If that is the case, then it's proven that the user owns the email and we can + # trust linking of the email password account. + email_verified = any( + lm.has_same_email_as(email_for_whom_token_was_generated) and lm.verified + for lm in existing_user.login_methods + ) + + # finally, we check if the primary user has any other email / phone number + # associated with this account - and if it does, then it means that + # there is a risk of account takeover, so we do not allow the token to be generated + has_other_email_or_phone = any( + ( + lm.email is not None + and not lm.has_same_email_as(email_for_whom_token_was_generated) ) - if isinstance(create_user_response, EmailAlreadyExistsError): - return PasswordResetTokenInvalidError() - else: - await mark_email_as_verified( - create_user_response.user.login_methods[0].recipe_user_id, - token_consumption_response.email, - ) - updated_user = await get_user( - create_user_response.user.id, - user_context, - ) - if updated_user is None: - raise Exception( - "Should never happen - user deleted after during password reset" - ) - create_user_response.user = updated_user - link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( - tenant_id=tenant_id, - input_user=create_user_response.user, - session=None, - user_context=user_context, - ) - user_after_linking = ( - link_res.user - if link_res.status == "OK" - else create_user_response.user - ) - assert user_after_linking is not None - return PasswordResetPostOkResult( - user=user_after_linking, - email=token_consumption_response.email, - ) - else: + or lm.phone_number is not None + for lm in existing_user.login_methods + ) + + if not email_verified and has_other_email_or_phone: + # We can return an invalid token error, because in this case the token should not have been created + # whenever they try to re-create it they'll see the appropriate error message + return PasswordResetTokenInvalidError() + + # since this doesn't result in linking and there is no risk of account takeover, we can allow the password reset to proceed return ( await do_update_password_and_verify_email_and_try_link_if_not_primary( RecipeUserId(user_id_for_whom_token_was_generated) ) ) + create_user_response = ( + await api_options.recipe_implementation.create_new_recipe_user( + tenant_id=tenant_id, + email=token_consumption_response.email, + password=new_password, + user_context=user_context, + ) + ) + if isinstance(create_user_response, EmailAlreadyExistsError): + return PasswordResetTokenInvalidError() + else: + await mark_email_as_verified( + create_user_response.user.login_methods[0].recipe_user_id, + token_consumption_response.email, + ) + updated_user = await get_user( + create_user_response.user.id, + user_context, + ) + if updated_user is None: + raise Exception( + "Should never happen - user deleted after during password reset" + ) + create_user_response.user = updated_user + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=create_user_response.user, + session=None, + user_context=user_context, + ) + user_after_linking = ( + link_res.user if link_res.status == "OK" else create_user_response.user + ) + assert user_after_linking is not None + return PasswordResetPostOkResult( + user=user_after_linking, + email=token_consumption_response.email, + ) + async def sign_in_post( self, form_fields: List[FormField], diff --git a/supertokens_python/recipe/emailpassword/types.py b/supertokens_python/recipe/emailpassword/types.py index e560f2e25..890ec14b4 100644 --- a/supertokens_python/recipe/emailpassword/types.py +++ b/supertokens_python/recipe/emailpassword/types.py @@ -76,7 +76,7 @@ def __init__( self.email = email def to_json(self) -> Dict[str, Any]: - return { + resp_json = { "id": self.id, "recipeUserId": ( self.recipe_user_id.get_as_string() @@ -85,6 +85,8 @@ def to_json(self) -> Dict[str, Any]: ), "email": self.email, } + # Remove items that are None + return {k: v for k, v in resp_json.items() if v is not None} class PasswordResetEmailTemplateVars: @@ -100,6 +102,7 @@ def __init__( def to_json(self) -> Dict[str, Any]: return { + "type": "PASSWORD_RESET", "user": self.user.to_json(), "passwordResetLink": self.password_reset_link, "tenantId": self.tenant_id, diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index ede8502f2..af5651978 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -559,6 +559,29 @@ async def sign_in_up_post( return tp_override_apis elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): + if "onlyLinkIfNewUserVerified" in eval_str: + + async def func4( + new_user_account: Any, + existing_user: Any, + session: Any, + tenant_id: Any, + user_context: Dict[str, Any], + ) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: + if user_context.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + + if ( + new_user_account.third_party is not None + and existing_user is not None + ): + if user_context.get("isVerified"): + return ShouldAutomaticallyLink(should_require_verification=True) + return ShouldNotAutomaticallyLink() + + return ShouldAutomaticallyLink(should_require_verification=True) + + return func4 async def func( i: Any, l: Any, o: Any, u: Any, a: Any # pylint: disable=unused-argument From b106a0ddb3c14f8a082e0568d5599dcddddd3905 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Mon, 14 Oct 2024 16:06:57 +0530 Subject: [PATCH 120/126] removes overwrite session flag --- supertokens_python/auth_utils.py | 24 +++++-------------- supertokens_python/recipe/session/__init__.py | 2 -- supertokens_python/recipe/session/recipe.py | 4 ---- supertokens_python/recipe/session/utils.py | 8 ------- tests/test-server/app.py | 5 +--- 5 files changed, 7 insertions(+), 36 deletions(-) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py index e471ea522..fe5a03eaf 100644 --- a/supertokens_python/auth_utils.py +++ b/supertokens_python/auth_utils.py @@ -21,7 +21,6 @@ ) from supertokens_python.recipe.multitenancy.asyncio import associate_user_to_tenant from supertokens_python.recipe.session.interfaces import SessionContainer -from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.asyncio import create_new_session, get_session from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo from supertokens_python.types import ( @@ -249,17 +248,13 @@ async def post_auth_checks( # If the new user wasn't linked to the current one, we check the config and overwrite the session if required # Note: we could also get here if MFA is enabled, but the app didn't want to link the user to the session user. # This is intentional, since the MFA and overwriteSessionDuringSignInUp configs should work independently. - overwrite_session_during_sign_in_up = ( - SessionRecipe.get_instance().config.overwrite_session_during_sign_in_up + resp_session = await create_new_session( + request, tenant_id, recipe_user_id, {}, {}, user_context ) - if overwrite_session_during_sign_in_up: - resp_session = await create_new_session( - request, tenant_id, recipe_user_id, {}, {}, user_context + if mfa_instance is not None: + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context ) - if mfa_instance is not None: - await mark_factor_as_complete_in_session( - resp_session, factor_id, user_context - ) else: log_debug_message("postAuthChecks creating session for first factor sign in/up") # If there is no input session, we do not need to do anything other checks and create a new session @@ -993,14 +988,7 @@ async def load_session_in_auth_api_if_needed( user_context: Dict[str, Any], ) -> Optional[SessionContainer]: - overwrite_session_during_sign_in_up = ( - SessionRecipe.get_instance().config.overwrite_session_during_sign_in_up - ) - - if ( - should_try_linking_with_session_user is not False - or not overwrite_session_during_sign_in_up - ): + if should_try_linking_with_session_user is not False: return await get_session( request, session_required=should_try_linking_with_session_user is True, diff --git a/supertokens_python/recipe/session/__init__.py b/supertokens_python/recipe/session/__init__.py index fae46a384..bd158dccb 100644 --- a/supertokens_python/recipe/session/__init__.py +++ b/supertokens_python/recipe/session/__init__.py @@ -52,7 +52,6 @@ def init( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, - overwrite_session_during_sign_in_up: Union[bool, None] = None, ) -> Callable[[AppInfo], RecipeModule]: return SessionRecipe.init( cookie_domain, @@ -68,5 +67,4 @@ def init( use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, - overwrite_session_during_sign_in_up, ) diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index f2e221d89..7e4eb4799 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -93,7 +93,6 @@ def __init__( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, - overwrite_session_during_sign_in_up: Union[bool, None] = None, ): super().__init__(recipe_id, app_info) self.config = validate_and_normalise_user_input( @@ -111,7 +110,6 @@ def __init__( use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, - overwrite_session_during_sign_in_up, ) self.openid_recipe = OpenIdRecipe( recipe_id, @@ -312,7 +310,6 @@ def init( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, - overwrite_session_during_sign_in_up: Union[bool, None] = None, ): def func(app_info: AppInfo): if SessionRecipe.__instance is None: @@ -332,7 +329,6 @@ def func(app_info: AppInfo): use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, - overwrite_session_during_sign_in_up, ) return SessionRecipe.__instance raise_general_exception( diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 13f3d8dca..96e5c43a4 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -391,7 +391,6 @@ def __init__( use_dynamic_access_token_signing_key: bool, expose_access_token_to_frontend_in_cookie_based_auth: bool, jwks_refresh_interval_sec: int, - overwrite_session_during_sign_in_up: bool, ): self.session_expired_status_code = session_expired_status_code self.invalid_claim_status_code = invalid_claim_status_code @@ -412,7 +411,6 @@ def __init__( self.framework = framework self.mode = mode self.jwks_refresh_interval_sec = jwks_refresh_interval_sec - self.overwrite_session_during_sign_in_up = overwrite_session_during_sign_in_up def validate_and_normalise_user_input( @@ -436,7 +434,6 @@ def validate_and_normalise_user_input( use_dynamic_access_token_signing_key: Union[bool, None] = None, expose_access_token_to_frontend_in_cookie_based_auth: Union[bool, None] = None, jwks_refresh_interval_sec: Union[int, None] = None, - overwrite_session_during_sign_in_up: Union[bool, None] = None, ): _ = cookie_same_site # we have this otherwise pylint complains that cookie_same_site is unused, but it is being used in the get_cookie_same_site function. if anti_csrf not in {"VIA_TOKEN", "VIA_CUSTOM_HEADER", "NONE", None}: @@ -564,11 +561,6 @@ def anti_csrf_function( use_dynamic_access_token_signing_key, expose_access_token_to_frontend_in_cookie_based_auth, jwks_refresh_interval_sec, - ( - overwrite_session_during_sign_in_up - if overwrite_session_during_sign_in_up is not None - else False - ), ) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index bd08f1674..8ee45dd7e 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -301,9 +301,6 @@ async def custom_unauthorised_callback( use_dynamic_access_token_signing_key=recipe_config_json.get( "useDynamicAccessTokenSigningKey" ), - overwrite_session_during_sign_in_up=recipe_config_json.get( - "overwriteSessionDuringSignInUp", None - ), override=session.InputOverrideConfig( apis=override_builder_with_logging( "Session.override.apis", @@ -703,7 +700,7 @@ def override_params(): @app.route("/test/featureflag", methods=["GET"]) # type: ignore def feature_flag(): - return jsonify([]) + return jsonify(["removedOverwriteSessionDuringSignInUp"]) @app.route("/test/resetoverrideparams", methods=["POST"]) # type: ignore From edddad062e2ff49ec13a8e58fa3129862289f18c Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Sat, 19 Oct 2024 16:57:33 +0530 Subject: [PATCH 121/126] fixes type --- supertokens_python/recipe/multifactorauth/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py index a53f12e30..1ffc2ec9f 100644 --- a/supertokens_python/recipe/multifactorauth/types.py +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -28,7 +28,7 @@ class MFAClaimValue: - c: Dict[str, Any] + c: Dict[str, int] v: bool def __init__(self, c: Dict[str, Any], v: bool): From 125264903c5dc2f1866828d3bf30943bd9309829 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Thu, 24 Oct 2024 23:46:27 +0530 Subject: [PATCH 122/126] changes versions --- setup.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index a1236113b..0550005b3 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ "Fastapi", "uvicorn==0.18.2", "python-dotenv==0.19.2", - "pyotp==2.9.0", + "pyotp<3", "aiofiles==24.1.0", ] ), @@ -28,7 +28,7 @@ "flask_cors", "Flask", "python-dotenv==0.19.2", - "pyotp==2.9.0", + "pyotp<3", ] ), "django": ( @@ -38,7 +38,7 @@ "django-stubs==1.9.0", "uvicorn==0.18.2", "python-dotenv==0.19.2", - "pyotp==2.9.0", + "pyotp<3", ] ), "django2x": ( @@ -48,7 +48,7 @@ "django-stubs==1.9.0", "gunicorn==20.1.0", "python-dotenv==0.19.2", - "pyotp==2.9.0", + "pyotp<3", ] ), "drf": ( @@ -62,7 +62,7 @@ "uvicorn==0.18.2", "python-dotenv==0.19.2", "tzdata==2021.5", - "pyotp==2.9.0", + "pyotp<3", ] ), } @@ -127,10 +127,11 @@ "asgiref>=3.4.1,<4", "typing_extensions>=4.1.1,<5.0.0", "Deprecated==1.2.13", - "phonenumbers==8.13.47", - "twilio==9.3.3", + "phonenumbers<9", + "twilio<10", "aiosmtplib>=1.1.6,<4.0.0", "pkce==1.0.3", + "pyotp<3", ], python_requires=">=3.7", include_package_data=True, From dd360387a5d024ce9af4e9efc64b7e54b0119da8 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 25 Oct 2024 00:01:32 +0530 Subject: [PATCH 123/126] fixes dependency version --- dev-requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 7c0d82388..7e3a7d720 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -86,4 +86,4 @@ Werkzeug==2.0.3 wrapt==1.13.3 zipp==3.7.0 pyotp==2.9.0 -aiofiles==24.1.0 \ No newline at end of file +aiofiles==23.2.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 0550005b3..eef32f349 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "uvicorn==0.18.2", "python-dotenv==0.19.2", "pyotp<3", - "aiofiles==24.1.0", + "aiofiles==23.2.1", ] ), "flask": ( From 4900c486087ea48f8c73118192490a0787375ea9 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 25 Oct 2024 11:59:52 +0530 Subject: [PATCH 124/126] changes to pre commit hook for debugging --- hooks/pre-commit.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hooks/pre-commit.sh b/hooks/pre-commit.sh index a18e7c13e..9d1e1bffa 100755 --- a/hooks/pre-commit.sh +++ b/hooks/pre-commit.sh @@ -21,7 +21,7 @@ then git stash push -k -u -- ${files_to_stash} >/dev/null 2>/dev/null fi -make check-lint >/dev/null 2>/dev/null +make check-lint linted=$? echo "$(tput setaf 3)* Properly linted?$(tput sgr 0)" @@ -36,7 +36,7 @@ else fi -make format >/dev/null 2>/dev/null +make format formatted=`git ls-files . --exclude-standard --others -m | wc -l` echo "$(tput setaf 3)* Properly formatted?$(tput sgr 0)" From 897588bd78fc58e3477d395ddd82a56df4773e00 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 25 Oct 2024 12:05:02 +0530 Subject: [PATCH 125/126] more fixes --- .../recipe/dashboard/api/userroles/get_role_to_user.py | 4 ++-- .../api/userroles/permissions/get_permissions_for_role.py | 4 ++-- .../recipe/dashboard/api/userroles/roles/get_all_roles.py | 4 ++-- supertokens_python/recipe/dashboard/utils.py | 3 ++- tests/auth-react/django3x/mysite/store.py | 3 ++- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py b/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py index 2cbd79bbc..3006c1b97 100644 --- a/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py +++ b/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Union, List from typing_extensions import Literal from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions @@ -8,7 +8,7 @@ class OkResponse(APIResponse): - def __init__(self, roles: list[str]): + def __init__(self, roles: List[str]): self.status: Literal["OK"] = "OK" self.roles = roles diff --git a/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py b/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py index 670c98f5f..bab61e045 100644 --- a/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py +++ b/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Dict, Union, List from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions from supertokens_python.recipe.userroles.asyncio import get_permissions_for_role @@ -12,7 +12,7 @@ class OkPermissionsForRoleResponse(APIResponse): - def __init__(self, permissions: list[str]): + def __init__(self, permissions: List[str]): self.status = "OK" self.permissions = permissions diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py b/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py index 35fca6842..71dc11d21 100644 --- a/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Union, List from typing_extensions import Literal from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions from supertokens_python.recipe.userroles.asyncio import get_all_roles @@ -7,7 +7,7 @@ class OkResponse(APIResponse): - def __init__(self, roles: list[str]): + def __init__(self, roles: List[str]): self.status: Literal["OK"] = "OK" self.roles = roles diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index ba7eae3bf..e05e53848 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -13,7 +13,8 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List, Literal +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List +from typing_extensions import Literal from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe if TYPE_CHECKING: diff --git a/tests/auth-react/django3x/mysite/store.py b/tests/auth-react/django3x/mysite/store.py index 07da199ba..3f2ece28c 100644 --- a/tests/auth-react/django3x/mysite/store.py +++ b/tests/auth-react/django3x/mysite/store.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Literal latest_url_with_token = "" From 04ffcf668faf49f62da56c0f7cb355e2a1b10376 Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 25 Oct 2024 12:08:39 +0530 Subject: [PATCH 126/126] undoes temp change for pre commit hook --- hooks/pre-commit.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hooks/pre-commit.sh b/hooks/pre-commit.sh index 9d1e1bffa..a18e7c13e 100755 --- a/hooks/pre-commit.sh +++ b/hooks/pre-commit.sh @@ -21,7 +21,7 @@ then git stash push -k -u -- ${files_to_stash} >/dev/null 2>/dev/null fi -make check-lint +make check-lint >/dev/null 2>/dev/null linted=$? echo "$(tput setaf 3)* Properly linted?$(tput sgr 0)" @@ -36,7 +36,7 @@ else fi -make format +make format >/dev/null 2>/dev/null formatted=`git ls-files . --exclude-standard --others -m | wc -l` echo "$(tput setaf 3)* Properly formatted?$(tput sgr 0)"