diff --git a/.circleci/setupAndTestBackendSDKWithFreeCore.sh b/.circleci/setupAndTestBackendSDKWithFreeCore.sh index c301a0c18..d774003dd 100755 --- a/.circleci/setupAndTestBackendSDKWithFreeCore.sh +++ b/.circleci/setupAndTestBackendSDKWithFreeCore.sh @@ -56,7 +56,6 @@ ST_CONNECTION_URI=http://localhost:8081 # start test-server pushd tests/test-server -sh setup-for-test.sh SUPERTOKENS_ENV=testing API_PORT=$API_PORT ST_CONNECTION_URI=$ST_CONNECTION_URI python3 app.py & popd diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 000000000..89a708854 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,10 @@ +You are an expert Python and Typescript developer. Your job is to convert node code (in typescript) into Python code. The python code then goes into this SDK. The python code style should keep in mind: +- Avoid using TypeDict +- Avoid using generic Dict as much as possible, except when defining the types for `user_context`. +- If a function has multiple `status` strings as outputs, then define one unique class per unique `status` string. The class name should be such that it indicates the status it is associated with. +- Variable and function names should be in snake_case. Class names in PascalCase. +- Whenever importing `Literal`, import it from `typing_extensions`, and not `types`. +- Do not use `|` for OR type, instead use `Union` +- When defining API interface functions, make sure the output classes inherit from `APIResponse` class, and that they have a `to_json` function defined whose output matches the structure of the provided Typescript code output objects. + +The semantic of the python code should be the same as what's of the provided Typescript code. \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 25629e81f..0c8c9f708 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,4 @@ [MASTER] - # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. @@ -20,11 +19,11 @@ fail-on= fail-under=10.0 # Files or directories to be skipped. They should be base names, not paths. -ignore=CVS +ignore=CVS, # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths and can be in Posix or Windows format. -ignore-paths= +ignore-paths=tests/test-server # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. @@ -116,7 +115,16 @@ disable=raw-checker-failed, global-statement, too-many-lines, duplicate-code, - too-many-return-statements + too-many-return-statements, + logging-not-lazy, + logging-fstring-interpolation, + consider-using-f-string, + consider-using-in, + no-else-return, + no-self-use, + no-else-raise, + too-many-nested-blocks, + # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..19251e616 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,90 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Flask, supertokens-website tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/frontendIntegration/flask-server/app.py", + "args": [ + "--port", + "8080" + ], + "cwd": "${workspaceFolder}/tests/frontendIntegration/flask-server", + "env": { + "FLASK_DEBUG": "1" + }, + "jinja": true + }, + { + "name": "Python: FastAPI, supertokens-website tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/frontendIntegration/fastapi-server/app.py", + "args": [ + "--port", + "8080" + ], + "cwd": "${workspaceFolder}/tests/frontendIntegration/fastapi-server", + "env": { + "FLASK_DEBUG": "1" + }, + "jinja": true + }, + { + "name": "Python: Flask, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/flask-server/app.py", + "args": [ + "--port", + "8083" + ], + "cwd": "${workspaceFolder}/tests/auth-react/flask-server", + "env": { + "FLASK_DEBUG": "1" + }, + "jinja": true + }, + { + "name": "Python: FastAPI, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/fastapi-server/app.py", + "args": [ + "--port", + "8083" + ], + "cwd": "${workspaceFolder}/tests/auth-react/fastapi-server", + "jinja": true + }, + { + "name": "Python: Django, supertokens-auth-react tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/auth-react/django3x/manage.py", + "args": [ + "runserver", + "0.0.0.0:8083" + ], + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "cwd": "${workspaceFolder}/tests/auth-react/django3x", + "jinja": true + }, + { + "name": "Python: backend-sdk-testing repo", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/test-server/app.py", + "cwd": "${workspaceFolder}/tests/test-server", + "env": { + "SUPERTOKENS_ENV": "testing", + "API_PORT": "3030", + "ST_CONNECTION_URI": "http://localhost:8081" + }, + "console": "integratedTerminal" + }, + ] +} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cfe8551d..c8f65c520 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,14 +8,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] + +## [0.25.0] - 2024-09-18 + +### Breaking changes +- `supertokens_python.recipe.emailverification.types.User` has been renamed to `supertokens_python.recipe.emailverification.types.EmailVerificationUser` +- The user object has been changed to be a global one, containing information about all emails, phone numbers, third party info and login methods associated with that user. +- Type of `get_email_for_user_id` in `emailverification.init` has changed +- Session recipe's error handlers take an extra param of recipe_user_id as well +- Session recipe, removes `validate_claims_in_jwt_payload` that is exposed to the user. +- TODO.. + ## [0.24.4] - 2024-10-16 - Updates `phonenumbers` and `twilio` to latest versions +>>>>>>> 0.24 ## [0.24.3] - 2024-09-24 - Adds support for form field related improvements by making fields accept any type of values - Adds support for optional fields to properly optional +>>>>>>> 0.24 ### Migration diff --git a/coreDriverInterfaceSupported.json b/coreDriverInterfaceSupported.json index 7ebde0bbf..11cc869d7 100644 --- a/coreDriverInterfaceSupported.json +++ b/coreDriverInterfaceSupported.json @@ -1,6 +1,6 @@ { "_comment": "contains a list of core-driver interfaces branch names that this core supports", "versions": [ - "3.1" + "5.1" ] -} +} \ No newline at end of file diff --git a/dev-requirements.txt b/dev-requirements.txt index f959daf22..7e3a7d720 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -85,3 +85,5 @@ uvicorn==0.18.2 Werkzeug==2.0.3 wrapt==1.13.3 zipp==3.7.0 +pyotp==2.9.0 +aiofiles==23.2.1 \ No newline at end of file diff --git a/frontendDriverInterfaceSupported.json b/frontendDriverInterfaceSupported.json index 0d7a266f3..71d6714d4 100644 --- a/frontendDriverInterfaceSupported.json +++ b/frontendDriverInterfaceSupported.json @@ -2,6 +2,10 @@ "_comment": "contains a list of frontend-driver interfaces branch names that this core supports", "versions": [ "1.17", - "2.0" + "1.18", + "1.19", + "2.0", + "3.0", + "3.1" ] } \ No newline at end of file diff --git a/setup.py b/setup.py index 201d5235b..eef32f349 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,8 @@ "Fastapi", "uvicorn==0.18.2", "python-dotenv==0.19.2", + "pyotp<3", + "aiofiles==23.2.1", ] ), "flask": ( @@ -26,6 +28,7 @@ "flask_cors", "Flask", "python-dotenv==0.19.2", + "pyotp<3", ] ), "django": ( @@ -35,6 +38,7 @@ "django-stubs==1.9.0", "uvicorn==0.18.2", "python-dotenv==0.19.2", + "pyotp<3", ] ), "django2x": ( @@ -44,6 +48,7 @@ "django-stubs==1.9.0", "gunicorn==20.1.0", "python-dotenv==0.19.2", + "pyotp<3", ] ), "drf": ( @@ -57,6 +62,7 @@ "uvicorn==0.18.2", "python-dotenv==0.19.2", "tzdata==2021.5", + "pyotp<3", ] ), } @@ -83,7 +89,7 @@ setup( name="supertokens_python", - version="0.24.4", + version="0.25.0", author="SuperTokens", license="Apache 2.0", author_email="team@supertokens.com", @@ -121,10 +127,11 @@ "asgiref>=3.4.1,<4", "typing_extensions>=4.1.1,<5.0.0", "Deprecated==1.2.13", - "phonenumbers==8.13.47", - "twilio==9.3.3", + "phonenumbers<9", + "twilio<10", "aiosmtplib>=1.1.6,<4.0.0", "pkce==1.0.3", + "pyotp<3", ], python_requires=">=3.7", include_package_data=True, diff --git a/supertokens_python/__init__.py b/supertokens_python/__init__.py index ea7ba506a..43d3573b2 100644 --- a/supertokens_python/__init__.py +++ b/supertokens_python/__init__.py @@ -17,6 +17,7 @@ from typing_extensions import Literal from supertokens_python.framework.request import BaseRequest +from supertokens_python.types import RecipeUserId from . import supertokens from .recipe_module import RecipeModule @@ -49,3 +50,7 @@ def get_request_from_user_context( user_context: Optional[Dict[str, Any]], ) -> Optional[BaseRequest]: return Supertokens.get_instance().get_request_from_user_context(user_context) + + +def convert_to_recipe_user_id(user_id: str) -> RecipeUserId: + return RecipeUserId(user_id) diff --git a/supertokens_python/asyncio/__init__.py b/supertokens_python/asyncio/__init__.py index 59796af43..17b83c1d1 100644 --- a/supertokens_python/asyncio/__init__.py +++ b/supertokens_python/asyncio/__init__.py @@ -24,7 +24,9 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) -from supertokens_python.types import UsersResponse +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.interfaces import GetUsersResult +from supertokens_python.types import AccountInfo, User async def get_users_oldest_first( @@ -34,15 +36,17 @@ async def get_users_oldest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: - return await Supertokens.get_instance().get_users( +) -> GetUsersResult: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.get_users( tenant_id, - "ASC", - limit, - pagination_token, - include_recipe_ids, - query, - user_context, + time_joined_order="ASC", + limit=limit, + pagination_token=pagination_token, + include_recipe_ids=include_recipe_ids, + query=query, + user_context=user_context, ) @@ -53,15 +57,17 @@ async def get_users_newest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: - return await Supertokens.get_instance().get_users( +) -> GetUsersResult: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.get_users( tenant_id, - "DESC", - limit, - pagination_token, - include_recipe_ids, - query, - user_context, + time_joined_order="DESC", + limit=limit, + pagination_token=pagination_token, + include_recipe_ids=include_recipe_ids, + query=query, + user_context=user_context, ) @@ -76,9 +82,27 @@ async def get_user_count( async def delete_user( - user_id: str, user_context: Optional[Dict[str, Any]] = None + user_id: str, + remove_all_linked_accounts: bool = True, + user_context: Optional[Dict[str, Any]] = None, ) -> None: - return await Supertokens.get_instance().delete_user(user_id, user_context) + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.delete_user( + user_id, + remove_all_linked_accounts=remove_all_linked_accounts, + user_context=user_context, + ) + + +async def get_user( + user_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Optional[User]: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + user_id=user_id, user_context=user_context + ) async def create_user_id_mapping( @@ -131,3 +155,19 @@ async def update_or_delete_user_id_mapping_info( return await Supertokens.get_instance().update_or_delete_user_id_mapping_info( user_id, user_id_type, external_user_id_info, user_context ) + + +async def list_users_by_account_info( + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool = False, + user_context: Optional[Dict[str, Any]] = None, +) -> List[User]: + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id, + account_info, + do_union_of_account_info, + user_context, + ) diff --git a/supertokens_python/auth_utils.py b/supertokens_python/auth_utils.py new file mode 100644 index 000000000..fe5a03eaf --- /dev/null +++ b/supertokens_python/auth_utils.py @@ -0,0 +1,998 @@ +from typing import Awaitable, Callable, Dict, Any, Optional, Union, List +from typing_extensions import Literal +from supertokens_python.framework import BaseRequest +from supertokens_python.recipe.accountlinking import ( + AccountInfoWithRecipeIdAndUserId, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.accountlinking.utils import ( + recipe_init_defined_should_do_automatic_account_linking, +) +from supertokens_python.recipe.multifactorauth.asyncio import ( + mark_factor_as_complete_in_session, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.utils import ( + is_valid_first_factor, + update_and_get_mfa_related_info_in_session, +) +from supertokens_python.recipe.multitenancy.asyncio import associate_user_to_tenant +from supertokens_python.recipe.session.interfaces import SessionContainer +from supertokens_python.recipe.session.asyncio import create_new_session, get_session +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import ( + AccountInfo, + User, + LoginMethod, +) +from supertokens_python.types import ( + RecipeUserId, +) +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.emailverification import ( + EmailVerificationClaim, +) +from supertokens_python.exceptions import BadInputError, raise_bad_input_exception +from supertokens_python.utils import log_debug_message +from .asyncio import get_user + + +class LinkingToSessionUserFailedError: + status: Literal["LINKING_TO_SESSION_USER_FAILED"] = "LINKING_TO_SESSION_USER_FAILED" + reason: Literal[ + "EMAIL_VERIFICATION_REQUIRED", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "INPUT_USER_IS_NOT_A_PRIMARY_USER", + ] + + def __init__( + self, + reason: Literal[ + "EMAIL_VERIFICATION_REQUIRED", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + "INPUT_USER_IS_NOT_A_PRIMARY_USER", + ], + ): + self.reason = reason + + +class OkResponse: + status: Literal["OK"] + valid_factor_ids: List[str] + is_first_factor: bool + + def __init__(self, valid_factor_ids: List[str], is_first_factor: bool): + self.status = "OK" + self.valid_factor_ids = valid_factor_ids + self.is_first_factor = is_first_factor + + +class SignUpNotAllowedResponse: + status: Literal["SIGN_UP_NOT_ALLOWED"] = "SIGN_UP_NOT_ALLOWED" + + +class SignInNotAllowedResponse: + status: Literal["SIGN_IN_NOT_ALLOWED"] = "SIGN_IN_NOT_ALLOWED" + + +async def pre_auth_checks( + authenticating_account_info: AccountInfoWithRecipeId, + authenticating_user: Union[User, None], + tenant_id: str, + factor_ids: List[str], + is_sign_up: bool, + is_verified: bool, + sign_in_verifies_login_method: bool, + skip_session_user_update_in_core: bool, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], +) -> Union[ + OkResponse, + SignUpNotAllowedResponse, + SignInNotAllowedResponse, + LinkingToSessionUserFailedError, +]: + valid_factor_ids: List[str] = [] + + if len(factor_ids) == 0: + raise Exception( + "This should never happen: empty factorIds array passed to preSignInChecks" + ) + + log_debug_message("preAuthChecks checking auth types") + auth_type_info = await check_auth_type_and_linking_status( + session, + should_try_linking_with_session_user, + authenticating_account_info, + authenticating_user, + skip_session_user_update_in_core, + user_context, + ) + if auth_type_info.status != "OK": + log_debug_message( + f"preAuthChecks returning {auth_type_info.status} from checkAuthType results" + ) + return auth_type_info + + if auth_type_info.is_first_factor: + log_debug_message("preAuthChecks getting valid first factors") + valid_first_factors = ( + await filter_out_invalid_first_factors_or_throw_if_all_are_invalid( + factor_ids, tenant_id, session is not None, user_context + ) + ) + valid_factor_ids = valid_first_factors + else: + assert isinstance( + auth_type_info, + (OkSecondFactorNotLinkedResponse, OkSecondFactorLinkedResponse), + ) + assert session is not None + log_debug_message("preAuthChecks getting valid secondary factors") + valid_factor_ids = ( + await filter_out_invalid_second_factors_or_throw_if_all_are_invalid( + factor_ids, + auth_type_info.input_user_already_linked_to_session_user, + auth_type_info.session_user, + session, + user_context, + ) + ) + + if not is_sign_up and authenticating_user is None: + raise Exception( + "This should never happen: preAuthChecks called with isSignUp: false, authenticatingUser: None" + ) + + if is_sign_up: + verified_in_session_user = not isinstance( + auth_type_info, OkFirstFactorResponse + ) and any( + lm.verified + and ( + lm.has_same_email_as(authenticating_account_info.email) + or lm.has_same_phone_number_as(authenticating_account_info.phone_number) + ) + for lm in auth_type_info.session_user.login_methods + ) + + log_debug_message("preAuthChecks checking if the user is allowed to sign up") + if not await AccountLinkingRecipe.get_instance().is_sign_up_allowed( + new_user=authenticating_account_info, + is_verified=is_verified + or sign_in_verifies_login_method + or verified_in_session_user, + tenant_id=tenant_id, + session=session, + user_context=user_context, + ): + return SignUpNotAllowedResponse() + elif authenticating_user is not None: + log_debug_message("preAuthChecks checking if the user is allowed to sign in") + if not await AccountLinkingRecipe.get_instance().is_sign_in_allowed( + user=authenticating_user, + account_info=authenticating_account_info, + sign_in_verifies_login_method=sign_in_verifies_login_method, + tenant_id=tenant_id, + session=session, + user_context=user_context, + ): + return SignInNotAllowedResponse() + + log_debug_message("preAuthChecks returning OK") + return OkResponse( + valid_factor_ids=valid_factor_ids, + is_first_factor=auth_type_info.is_first_factor, + ) + + +class PostAuthChecksOkResponse: + status: Literal["OK"] + session: SessionContainer + user: User + + def __init__(self, status: Literal["OK"], session: SessionContainer, user: User): + self.status = status + self.session = session + self.user = user + + +class PostAuthChecksSignInNotAllowedResponse: + status: Literal["SIGN_IN_NOT_ALLOWED"] + + +async def post_auth_checks( + authenticated_user: User, + recipe_user_id: RecipeUserId, + is_sign_up: bool, + factor_id: str, + session: Union[SessionContainer, None], + tenant_id: str, + user_context: Dict[str, Any], + request: BaseRequest, +) -> Union[PostAuthChecksOkResponse, PostAuthChecksSignInNotAllowedResponse]: + log_debug_message( + f"postAuthChecks called {'with' if session is not None else 'without'} a session to " + f"{'sign up' if is_sign_up else 'sign in'} with {factor_id}" + ) + + mfa_instance = MultiFactorAuthRecipe.get_instance() + + resp_session = session + if session is not None: + authenticated_user_linked_to_session_user = any( + lm.recipe_user_id.get_as_string() + == session.get_recipe_user_id(user_context).get_as_string() + for lm in authenticated_user.login_methods + ) + if authenticated_user_linked_to_session_user: + log_debug_message("postAuthChecks session and input user got linked") + if mfa_instance is not None: + log_debug_message("postAuthChecks marking factor as completed") + # if the authenticating user is linked to the current session user (it means that the factor got set up or completed), + # we mark it as completed in the session. + assert resp_session is not None + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context + ) + else: + log_debug_message("postAuthChecks checking overwriteSessionDuringSignInUp") + # If the new user wasn't linked to the current one, we check the config and overwrite the session if required + # Note: we could also get here if MFA is enabled, but the app didn't want to link the user to the session user. + # This is intentional, since the MFA and overwriteSessionDuringSignInUp configs should work independently. + resp_session = await create_new_session( + request, tenant_id, recipe_user_id, {}, {}, user_context + ) + if mfa_instance is not None: + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context + ) + else: + log_debug_message("postAuthChecks creating session for first factor sign in/up") + # If there is no input session, we do not need to do anything other checks and create a new session + resp_session = await create_new_session( + request, tenant_id, recipe_user_id, {}, {}, user_context + ) + + # Here we can always mark the factor as completed, since we just created the session + if mfa_instance is not None: + await mark_factor_as_complete_in_session( + resp_session, factor_id, user_context + ) + + assert resp_session is not None + return PostAuthChecksOkResponse( + status="OK", session=resp_session, user=authenticated_user + ) + + +class AuthenticatingUserInfo: + def __init__(self, user: User, login_method: Union[LoginMethod, None]): + self.user = user + self.login_method = login_method + + +async def get_authenticating_user_and_add_to_current_tenant_if_required( + recipe_id: str, + email: Optional[str], + phone_number: Optional[str], + third_party: Optional[ThirdPartyInfo], + tenant_id: str, + session: Optional[SessionContainer], + check_credentials_on_tenant: Callable[[str], Awaitable[bool]], + user_context: Dict[str, Any], +) -> Optional[AuthenticatingUserInfo]: + i = 0 + while i < 300: + account_info = { + "email": email, + "phoneNumber": phone_number, + "thirdParty": third_party, + } + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired called with {account_info}" + ) + existing_users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo( + email=email, phone_number=phone_number, third_party=third_party + ), + do_union_of_account_info=True, + user_context=user_context, + ) + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired got {len(existing_users)} users from the core resp" + ) + users_with_matching_login_methods = [ + AuthenticatingUserInfo( + user=user, + login_method=next( + ( + lm + for lm in user.login_methods + if lm.recipe_id == recipe_id + and ( + (email is not None and lm.has_same_email_as(email)) + or lm.has_same_phone_number_as(phone_number) + or lm.has_same_third_party_info_as(third_party) + ) + ), + None, + ), + ) + for user in existing_users + ] + users_with_matching_login_methods = [ + u for u in users_with_matching_login_methods if u.login_method is not None + ] + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired got {len(users_with_matching_login_methods)} users with matching login methods" + ) + if len(users_with_matching_login_methods) > 1: + raise Exception( + "You have found a bug. Please report it on https://github.com/supertokens/supertokens-node/issues" + ) + authenticating_user = ( + AuthenticatingUserInfo( + users_with_matching_login_methods[0].user, + users_with_matching_login_methods[0].login_method, + ) + if users_with_matching_login_methods + else None + ) + + if authenticating_user is None and session is not None: + log_debug_message( + "getAuthenticatingUserAndAddToCurrentTenantIfRequired checking session user" + ) + session_user = await get_user( + session.get_user_id(user_context), user_context + ) + if session_user is None: + raise UnauthorisedError( + "Session user not found", + ) + + if not session_user.is_primary_user: + log_debug_message( + "getAuthenticatingUserAndAddToCurrentTenantIfRequired session user is non-primary so returning early without checking other tenants" + ) + return None + + matching_login_methods_from_session_user = [ + lm + for lm in session_user.login_methods + if lm.recipe_id == recipe_id + and ( + lm.has_same_email_as(email) + or lm.has_same_phone_number_as(phone_number) + or lm.has_same_third_party_info_as(third_party) + ) + ] + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired session has {len(matching_login_methods_from_session_user)} matching login methods" + ) + + if any( + tenant_id in lm.tenant_ids + for lm in matching_login_methods_from_session_user + ): + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired session has {len(matching_login_methods_from_session_user)} matching login methods" + ) + return AuthenticatingUserInfo( + user=session_user, + login_method=next( + lm + for lm in matching_login_methods_from_session_user + if tenant_id in lm.tenant_ids + ), + ) + + go_to_retry = False + for lm in matching_login_methods_from_session_user: + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired session checking credentials on {lm.tenant_ids[0]}" + ) + if await check_credentials_on_tenant(lm.tenant_ids[0]): + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired associating user from {lm.tenant_ids[0]} with current tenant" + ) + associate_res = await associate_user_to_tenant( + tenant_id, lm.recipe_user_id, user_context + ) + log_debug_message( + f"getAuthenticatingUserAndAddToCurrentTenantIfRequired associating returned {associate_res.status}" + ) + if associate_res.status == "OK": + lm.tenant_ids.append(tenant_id) + return AuthenticatingUserInfo( + user=session_user, login_method=lm + ) + if associate_res.status in [ + "UNKNOWN_USER_ID_ERROR", + "EMAIL_ALREADY_EXISTS_ERROR", + "PHONE_NUMBER_ALREADY_EXISTS_ERROR", + "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR", + ]: + go_to_retry = True + break + if associate_res.status == "ASSOCIATION_NOT_ALLOWED_ERROR": + raise UnauthorisedError( + "Session user not associated with the session tenant" + ) + if go_to_retry: + log_debug_message( + "getAuthenticatingUserAndAddToCurrentTenantIfRequired retrying" + ) + i += 1 + continue + return authenticating_user + raise Exception( + "This should never happen: ran out of retries for getAuthenticatingUserAndAddToCurrentTenantIfRequired" + ) + + +class OkFirstFactorResponse: + status: Literal["OK"] = "OK" + is_first_factor: Literal[True] = True + + +class OkSecondFactorLinkedResponse: + status: Literal["OK"] = "OK" + is_first_factor: Literal[False] = False + input_user_already_linked_to_session_user: Literal[True] = True + session_user: User + + def __init__(self, session_user: User): + self.session_user = session_user + + +class OkSecondFactorNotLinkedResponse: + status: Literal["OK"] = "OK" + is_first_factor: Literal[False] = False + input_user_already_linked_to_session_user: Literal[False] = False + session_user: User + linking_to_session_user_requires_verification: bool + + def __init__( + self, + session_user: User, + linking_to_session_user_requires_verification: bool, + ): + self.session_user = session_user + self.linking_to_session_user_requires_verification = ( + linking_to_session_user_requires_verification + ) + + +async def check_auth_type_and_linking_status( + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + account_info: AccountInfoWithRecipeId, + input_user: Union[User, None], + skip_session_user_update_in_core: bool, + user_context: Dict[str, Any], +) -> Union[ + OkFirstFactorResponse, + OkSecondFactorLinkedResponse, + OkSecondFactorNotLinkedResponse, + LinkingToSessionUserFailedError, +]: + log_debug_message("check_auth_type_and_linking_status called") + session_user: Union[User, None] = None + if session is None: + if should_try_linking_with_session_user is True: + raise UnauthorisedError( + "Session not found but shouldTryLinkingWithSessionUser is true" + ) + log_debug_message( + "check_auth_type_and_linking_status returning first factor because there is no session" + ) + return OkFirstFactorResponse() + else: + if should_try_linking_with_session_user is False: + # In our normal flows this should never happen - but some user overrides might do this. + # Anyway, since should_try_linking_with_session_user explicitly set to false, it's safe to consider this a first factor + log_debug_message( + "check_auth_type_and_linking_status returning first factor because should_try_linking_with_session_user is False" + ) + return OkFirstFactorResponse() + if not recipe_init_defined_should_do_automatic_account_linking(): + if should_try_linking_with_session_user is True: + raise Exception( + "Please initialise the account linking recipe and define should_do_automatic_account_linking to enable MFA" + ) + else: + if MultiFactorAuthRecipe.get_instance() is not None: + raise Exception( + "Please initialise the account linking recipe and define should_do_automatic_account_linking to enable MFA" + ) + else: + return OkFirstFactorResponse() + + if input_user is not None and input_user.id == session.get_user_id(): + log_debug_message( + "check_auth_type_and_linking_status returning secondary factor, session and input user are the same" + ) + return OkSecondFactorLinkedResponse( + session_user=input_user, + ) + + log_debug_message( + f"check_auth_type_and_linking_status loading session user, {input_user.id if input_user else None} === {session.get_user_id()}" + ) + session_user_result = await try_and_make_session_user_into_a_primary_user( + session, skip_session_user_update_in_core, user_context + ) + if session_user_result.status == "SHOULD_AUTOMATICALLY_LINK_FALSE": + if should_try_linking_with_session_user is True: + raise BadInputError( + "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + ) + return OkFirstFactorResponse() + elif ( + session_user_result.status + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return LinkingToSessionUserFailedError( + reason="SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + + session_user = session_user_result.user + + should_link = await AccountLinkingRecipe.get_instance().config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + account_info + ), + session_user, + session, + session.get_tenant_id(), + user_context, + ) + log_debug_message( + f"check_auth_type_and_linking_status session user <-> input user should_do_automatic_account_linking returned {should_link}" + ) + + if isinstance(should_link, ShouldNotAutomaticallyLink): + if should_try_linking_with_session_user is True: + raise BadInputError( + "should_do_automatic_account_linking returned false when creating primary user but shouldTryLinkingWithSessionUser is true" + ) + return OkFirstFactorResponse() + else: + return OkSecondFactorNotLinkedResponse( + session_user=session_user, + linking_to_session_user_requires_verification=should_link.should_require_verification, + ) + + +class OkResponse2: + status: Literal["OK"] + user: User + + def __init__(self, user: User): + self.status = "OK" + self.user: User = user + + +async def link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id: str, + input_user: User, + recipe_user_id: RecipeUserId, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], +) -> Union[OkResponse2, LinkingToSessionUserFailedError,]: + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info called" + ) + + async def retry(): + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info retrying...." + ) + return await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=input_user, + session=session, + recipe_user_id=recipe_user_id, + should_try_linking_with_session_user=should_try_linking_with_session_user, + user_context=user_context, + ) + + auth_login_method = next( + ( + lm + for lm in input_user.login_methods + if lm.recipe_user_id.get_as_string() == recipe_user_id.get_as_string() + ), + None, + ) + if auth_login_method is None: + raise Exception( + "This should never happen: the recipe_user_id and user is inconsistent in create_primary_user_id_or_link_by_account_info params" + ) + + auth_type_res = await check_auth_type_and_linking_status( + session, + should_try_linking_with_session_user, + AccountInfoWithRecipeId( + recipe_id=auth_login_method.recipe_id, + email=auth_login_method.email, + phone_number=auth_login_method.phone_number, + third_party=auth_login_method.third_party, + ), + input_user, + False, + user_context, + ) + + if not isinstance( + auth_type_res, + ( + OkFirstFactorResponse, + OkSecondFactorLinkedResponse, + OkSecondFactorNotLinkedResponse, + ), + ): + return LinkingToSessionUserFailedError(reason=auth_type_res.reason) + + if isinstance(auth_type_res, OkFirstFactorResponse): + if not recipe_init_defined_should_do_automatic_account_linking(): + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info skipping link by account info because this is a first factor auth and the app hasn't defined should_do_automatic_account_linking" + ) + return OkResponse2(user=input_user) + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info trying to link by account info because this is a first factor auth" + ) + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + input_user=input_user, + session=session, + tenant_id=tenant_id, + user_context=user_context, + ) + if link_res.status == "OK": + assert link_res.user is not None + return OkResponse2(user=link_res.user) + if link_res.status == "NO_LINK": + return OkResponse2(user=input_user) + return await retry() + + if isinstance(auth_type_res, OkSecondFactorLinkedResponse): + return OkResponse2(user=auth_type_res.session_user) + + log_debug_message( + "link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info trying to link by session info" + ) + session_linking_res = await try_linking_by_session( + session_user=auth_type_res.session_user, + authenticated_user=input_user, + auth_login_method=auth_login_method, + linking_to_session_user_requires_verification=auth_type_res.linking_to_session_user_requires_verification, + user_context=user_context, + ) + if isinstance(session_linking_res, LinkingToSessionUserFailedError): + if session_linking_res.reason == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + return await retry() + else: + return session_linking_res + else: + return session_linking_res + + +class ShouldAutomaticallyLinkFalseResponse: + status: Literal["SHOULD_AUTOMATICALLY_LINK_FALSE"] + + def __init__(self): + self.status = "SHOULD_AUTOMATICALLY_LINK_FALSE" + + +class AccountInfoAlreadyAssociatedResponse: + status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] + + def __init__(self): + self.status = ( + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + + +async def try_and_make_session_user_into_a_primary_user( + session: SessionContainer, + skip_session_user_update_in_core: bool, + user_context: Dict[str, Any], +) -> Union[ + OkResponse2, + ShouldAutomaticallyLinkFalseResponse, + AccountInfoAlreadyAssociatedResponse, +]: + log_debug_message("try_and_make_session_user_into_a_primary_user called") + session_user = await get_user(session.get_user_id(), user_context) + if session_user is None: + raise UnauthorisedError("Session user not found") + + if session_user.is_primary_user: + log_debug_message( + "try_and_make_session_user_into_a_primary_user session user already primary" + ) + return OkResponse2(user=session_user) + else: + log_debug_message( + "try_and_make_session_user_into_a_primary_user not primary user yet" + ) + + account_linking_instance = AccountLinkingRecipe.get_instance() + should_do_account_linking = ( + await account_linking_instance.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + session_user.login_methods[0] + ), + None, + session, + session.get_tenant_id(), + user_context, + ) + ) + log_debug_message( + f"try_and_make_session_user_into_a_primary_user should_do_account_linking: {should_do_account_linking}" + ) + + if isinstance(should_do_account_linking, ShouldAutomaticallyLink): + if skip_session_user_update_in_core: + return OkResponse2(user=session_user) + if ( + should_do_account_linking.should_require_verification + and not session_user.login_methods[0].verified + ): + if ( + await session.get_claim_value(EmailVerificationClaim, user_context) + ) is not False: + log_debug_message( + "try_and_make_session_user_into_a_primary_user updating emailverification status in session" + ) + await session.set_claim_value( + EmailVerificationClaim, False, user_context + ) + log_debug_message( + "try_and_make_session_user_into_a_primary_user throwing validation error" + ) + await session.assert_claims( + [EmailVerificationClaim.validators.is_verified()], user_context + ) + raise Exception( + "This should never happen: email verification claim validator passed after setting value to false" + ) + create_primary_user_res = await account_linking_instance.recipe_implementation.create_primary_user( + recipe_user_id=session_user.login_methods[0].recipe_user_id, + user_context=user_context, + ) + log_debug_message( + f"try_and_make_session_user_into_a_primary_user create_primary_user returned {create_primary_user_res.status}" + ) + if ( + create_primary_user_res.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + raise UnauthorisedError("Session user not found") + elif create_primary_user_res.status == "OK": + return OkResponse2(user=create_primary_user_res.user) + else: + return AccountInfoAlreadyAssociatedResponse() + else: + return ShouldAutomaticallyLinkFalseResponse() + + +async def try_linking_by_session( + linking_to_session_user_requires_verification: bool, + auth_login_method: LoginMethod, + authenticated_user: User, + session_user: User, + user_context: Dict[str, Any], +) -> Union[OkResponse2, LinkingToSessionUserFailedError,]: + log_debug_message("tryLinkingBySession called") + + session_user_has_verified_account_info = any( + ( + lm.has_same_email_as(auth_login_method.email) + or lm.has_same_phone_number_as(auth_login_method.phone_number) + ) + and lm.verified + for lm in session_user.login_methods + ) + + can_link_based_on_verification = ( + not linking_to_session_user_requires_verification + or auth_login_method.verified + or session_user_has_verified_account_info + ) + + if not can_link_based_on_verification: + return LinkingToSessionUserFailedError(reason="EMAIL_VERIFICATION_REQUIRED") + + link_accounts_result = ( + await AccountLinkingRecipe.get_instance().recipe_implementation.link_accounts( + recipe_user_id=authenticated_user.login_methods[0].recipe_user_id, + primary_user_id=session_user.id, + user_context=user_context, + ) + ) + + if link_accounts_result.status == "OK": + log_debug_message( + "tryLinkingBySession successfully linked input user to session user" + ) + return OkResponse2(user=link_accounts_result.user) + elif ( + link_accounts_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + "tryLinkingBySession linking to session user failed because of a race condition - input user linked to another user" + ) + return LinkingToSessionUserFailedError( + reason="RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + elif link_accounts_result.status == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + log_debug_message( + "tryLinkingBySession linking to session user failed because of a race condition - INPUT_USER_IS_NOT_A_PRIMARY_USER, should retry" + ) + return LinkingToSessionUserFailedError( + reason="INPUT_USER_IS_NOT_A_PRIMARY_USER" + ) + else: + log_debug_message( + "tryLinkingBySession linking to session user failed because of a race condition - input user has another primary user it can be linked to" + ) + return LinkingToSessionUserFailedError( + reason="ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ) + + +async def filter_out_invalid_first_factors_or_throw_if_all_are_invalid( + factor_ids: List[str], + tenant_id: str, + has_session: bool, + user_context: Dict[str, Any], +) -> List[str]: + valid_factor_ids: List[str] = [] + for _id in factor_ids: + valid_res = await is_valid_first_factor(tenant_id, _id, user_context) + + if valid_res == "TENANT_NOT_FOUND_ERROR": + if has_session: + raise UnauthorisedError("Tenant not found") + else: + raise Exception("Tenant not found error.") + elif valid_res == "OK": + valid_factor_ids.append(_id) + + if len(valid_factor_ids) == 0: + if not has_session: + raise UnauthorisedError( + "A valid session is required to authenticate with secondary factors" + ) + else: + raise_bad_input_exception( + "First factor sign in/up called for a non-first factor with an active session. This might indicate that you are trying to use this as a secondary factor, but disabled account linking." + ) + + return valid_factor_ids + + +async def filter_out_invalid_second_factors_or_throw_if_all_are_invalid( + factor_ids: List[str], + input_user_already_linked_to_session_user: bool, + session_user: User, + session: SessionContainer, + user_context: Dict[str, Any], +) -> List[str]: + log_debug_message( + f"filter_out_invalid_second_factors_or_throw_if_all_are_invalid called for {', '.join(factor_ids)}" + ) + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is not None: + if not input_user_already_linked_to_session_user: + factors_set_up_for_user_prom: Optional[List[str]] = None + mfa_info_prom = None + + async def get_factors_set_up_for_user() -> List[str]: + nonlocal factors_set_up_for_user_prom + if factors_set_up_for_user_prom is None: + factors_set_up_for_user_prom = await mfa_instance.recipe_implementation.get_factors_setup_for_user( + user=session_user, user_context=user_context + ) + assert factors_set_up_for_user_prom is not None + return factors_set_up_for_user_prom + + async def get_mfa_requirements_for_auth(): + nonlocal mfa_info_prom + if mfa_info_prom is None: + + mfa_info_prom = await update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + return mfa_info_prom.mfa_requirements_for_auth + + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid checking if linking is allowed by the mfa recipe" + ) + caught_setup_factor_error: Optional[Exception] = None + valid_factor_ids: List[str] = [] + + for _id in factor_ids: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid checking assert_allowed_to_setup_factor_else_throw_invalid_claim_error" + ) + try: + await mfa_instance.recipe_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + factor_id=_id, + session=session, + factors_set_up_for_user=get_factors_set_up_for_user, + mfa_requirements_for_auth=get_mfa_requirements_for_auth, + user_context=user_context, + ) + log_debug_message( + f"filter_out_invalid_second_factors_or_throw_if_all_are_invalid {id} valid because assert_allowed_to_setup_factor_else_throw_invalid_claim_error passed" + ) + valid_factor_ids.append(_id) + except Exception as err: + log_debug_message( + f"filter_out_invalid_second_factors_or_throw_if_all_are_invalid assert_allowed_to_setup_factor_else_throw_invalid_claim_error failed for {id}" + ) + caught_setup_factor_error = err + + if len(valid_factor_ids) == 0: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid rethrowing error from assert_allowed_to_setup_factor_else_throw_invalid_claim_error because we found no valid factors" + ) + if caught_setup_factor_error is not None: + raise caught_setup_factor_error + else: + raise Exception("Should never come here") + + return valid_factor_ids + else: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid allowing all factors because it'll not create new link" + ) + return factor_ids + else: + log_debug_message( + "filter_out_invalid_second_factors_or_throw_if_all_are_invalid allowing all factors because MFA is not enabled" + ) + return factor_ids + + +def is_fake_email(email: str) -> bool: + return email.endswith("@stfakeemail.supertokens.com") or email.endswith( + ".fakeemail.com" + ) # .fakeemail.com for older users + + +async def load_session_in_auth_api_if_needed( + request: BaseRequest, + should_try_linking_with_session_user: Optional[bool], + user_context: Dict[str, Any], +) -> Optional[SessionContainer]: + + if should_try_linking_with_session_user is not False: + return await get_session( + request, + session_required=should_try_linking_with_session_user is True, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + return None diff --git a/supertokens_python/constants.py b/supertokens_python/constants.py index 067c74e31..62464cd86 100644 --- a/supertokens_python/constants.py +++ b/supertokens_python/constants.py @@ -11,10 +11,11 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + from __future__ import annotations -SUPPORTED_CDI_VERSIONS = ["3.0"] -VERSION = "0.24.4" +SUPPORTED_CDI_VERSIONS = ["5.1"] +VERSION = "0.25.0" TELEMETRY = "/telemetry" USER_COUNT = "/users/count" USER_DELETE = "/user/remove" @@ -27,6 +28,6 @@ FDI_KEY_HEADER = "fdi-version" API_VERSION = "/apiversion" API_VERSION_HEADER = "cdi-version" -DASHBOARD_VERSION = "0.7" +DASHBOARD_VERSION = "0.13" ONE_YEAR_IN_MS = 31536000000 RATE_LIMIT_STATUS_CODE = 429 diff --git a/supertokens_python/normalised_url_domain.py b/supertokens_python/normalised_url_domain.py index ffd57a738..1309f234b 100644 --- a/supertokens_python/normalised_url_domain.py +++ b/supertokens_python/normalised_url_domain.py @@ -11,31 +11,29 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from __future__ import annotations +from __future__ import annotations from typing import TYPE_CHECKING from urllib.parse import urlparse - from .utils import is_an_ip_address +from .exceptions import raise_general_exception if TYPE_CHECKING: pass -from .exceptions import raise_general_exception class NormalisedURLDomain: def __init__(self, url: str): - self.__value = normalise_domain_path_or_throw_error(url) + self.__value = normalise_url_domain_or_throw_error(url) def get_as_string_dangerous(self): return self.__value -def normalise_domain_path_or_throw_error( +def normalise_url_domain_or_throw_error( input_str: str, ignore_protocol: bool = False ) -> str: input_str = input_str.strip().lower() - try: if ( (not input_str.startswith("http://")) @@ -44,7 +42,6 @@ def normalise_domain_path_or_throw_error( ): raise Exception("converting to proper URL") url_obj = urlparse(input_str) - if ignore_protocol: if url_obj.hostname is None: raise Exception("Should never come here") @@ -56,17 +53,13 @@ def normalise_domain_path_or_throw_error( input_str = "https://" + url_obj.netloc else: input_str = url_obj.scheme + "://" + url_obj.netloc - return input_str except Exception: pass - if input_str.startswith("/"): raise_general_exception("Please provide a valid domain name") - if input_str.startswith("."): input_str = input_str[1:] - if ( ("." in input_str or input_str.startswith("localhost")) and (not input_str.startswith("http://")) @@ -75,7 +68,7 @@ def normalise_domain_path_or_throw_error( input_str = "https://" + input_str try: urlparse(input_str) - return normalise_domain_path_or_throw_error(input_str, True) + return normalise_url_domain_or_throw_error(input_str, True) except Exception: pass raise_general_exception("Please provide a valid domain name") diff --git a/supertokens_python/normalised_url_path.py b/supertokens_python/normalised_url_path.py index a99142415..31075a303 100644 --- a/supertokens_python/normalised_url_path.py +++ b/supertokens_python/normalised_url_path.py @@ -13,13 +13,12 @@ # under the License. from __future__ import annotations - from typing import TYPE_CHECKING from urllib.parse import urlparse +from .exceptions import raise_general_exception if TYPE_CHECKING: pass -from .exceptions import raise_general_exception class NormalisedURLPath: @@ -40,33 +39,26 @@ def equals(self, other: NormalisedURLPath) -> bool: def is_a_recipe_path(self) -> bool: parts = self.__value.split("/") - return (len(parts) > 1 and parts[1] == "recipe") or ( - len(parts) > 2 and parts[2] == "recipe" - ) + return parts[1] == "recipe" or (len(parts) > 2 and parts[2] == "recipe") def normalise_url_path_or_throw_error(input_str: str) -> str: input_str = input_str.strip().lower() - try: - if (not input_str.startswith("http://")) and ( - not input_str.startswith("https://") - ): + if not input_str.startswith("http://") and not input_str.startswith("https://"): raise Exception("converting to proper URL") url_obj = urlparse(input_str) input_str = url_obj.path - if input_str.endswith("/"): return input_str[:-1] - return input_str except Exception: pass if ( (domain_given(input_str) or input_str.startswith("localhost")) - and (not input_str.startswith("http://")) - and (not input_str.startswith("https://")) + and not input_str.startswith("http://") + and not input_str.startswith("https://") ): input_str = "http://" + input_str return normalise_url_path_or_throw_error(input_str) @@ -82,23 +74,18 @@ def normalise_url_path_or_throw_error(input_str: str) -> str: def domain_given(input_str: str) -> bool: - if ("." not in input_str) or (input_str.startswith("/")): + if "." not in input_str or input_str.startswith("/"): return False - try: + if not "http://" in input_str and not "https://" in input_str: + raise Exception("Trying with http") url = urlparse(input_str) - if url.hostname is None: - raise Exception("Should never come here") - return url.hostname.find(".") != -1 + return url.hostname is not None and "." in url.hostname except Exception: pass - try: url = urlparse("http://" + input_str) - if url.hostname is None: - raise Exception("Should never come here") - return url.hostname.find(".") != -1 + return url.hostname is not None and "." in url.hostname except Exception: pass - return False diff --git a/supertokens_python/post_init_callbacks.py b/supertokens_python/post_init_callbacks.py index 227acb78e..ddbf0afa9 100644 --- a/supertokens_python/post_init_callbacks.py +++ b/supertokens_python/post_init_callbacks.py @@ -18,15 +18,18 @@ class PostSTInitCallbacks: """Callbacks that are called after the SuperTokens instance is initialized.""" - callbacks: List[Callable[[], None]] = [] + post_init_callbacks: List[Callable[[], None]] = [] @staticmethod def add_post_init_callback(cb: Callable[[], None]) -> None: - PostSTInitCallbacks.callbacks.append(cb) + PostSTInitCallbacks.post_init_callbacks.append(cb) @staticmethod def run_post_init_callbacks() -> None: - for cb in PostSTInitCallbacks.callbacks: + for cb in PostSTInitCallbacks.post_init_callbacks: cb() + PostSTInitCallbacks.post_init_callbacks = [] - PostSTInitCallbacks.callbacks = [] + @staticmethod + def reset(): + PostSTInitCallbacks.post_init_callbacks = [] diff --git a/supertokens_python/process_state.py b/supertokens_python/process_state.py index e50fb95b2..83429697c 100644 --- a/supertokens_python/process_state.py +++ b/supertokens_python/process_state.py @@ -12,22 +12,27 @@ # License for the specific language governing permissions and limitations # under the License. from os import environ -from typing import List +from typing import List, Optional from enum import Enum -class AllowedProcessStates(Enum): - CALLING_SERVICE_IN_VERIFY = 1 - CALLING_SERVICE_IN_GET_HANDSHAKE_INFO = 2 - CALLING_SERVICE_IN_GET_API_VERSION = 3 - CALLING_SERVICE_IN_REQUEST_HELPER = 4 +class PROCESS_STATE(Enum): + CALLING_SERVICE_IN_VERIFY = 0 + CALLING_SERVICE_IN_GET_API_VERSION = 1 + CALLING_SERVICE_IN_REQUEST_HELPER = 2 + MULTI_JWKS_VALIDATION = 3 + IS_SIGN_IN_UP_ALLOWED_NO_PRIMARY_USER_EXISTS = 4 + IS_SIGN_UP_ALLOWED_CALLED = 5 + IS_SIGN_IN_ALLOWED_CALLED = 6 + IS_SIGN_IN_UP_ALLOWED_HELPER_CALLED = 7 + ADDING_NO_CACHE_HEADER_IN_FETCH = 8 class ProcessState: __instance = None def __init__(self): - self.history: List[AllowedProcessStates] = [] + self.history: List[PROCESS_STATE] = [] @staticmethod def get_instance(): @@ -35,9 +40,37 @@ def get_instance(): ProcessState.__instance = ProcessState() return ProcessState.__instance - def add_state(self, state: AllowedProcessStates): + def add_state(self, state: PROCESS_STATE): if ("SUPERTOKENS_ENV" in environ) and (environ["SUPERTOKENS_ENV"] == "testing"): self.history.append(state) def reset(self): self.history = [] + + def get_event_by_last_event_by_name( + self, state: PROCESS_STATE + ) -> Optional[PROCESS_STATE]: + for event in reversed(self.history): + if event == state: + return event + return None + + def wait_for_event( + self, state: PROCESS_STATE, time_in_ms: int = 7000 + ) -> Optional[PROCESS_STATE]: + from time import time, sleep + + start_time = time() + + def try_and_get() -> Optional[PROCESS_STATE]: + result = self.get_event_by_last_event_by_name(state) + if result is None: + if (time() - start_time) * 1000 > time_in_ms: + return None + else: + sleep(1) + return try_and_get() + else: + return result + + return try_and_get() diff --git a/supertokens_python/querier.py b/supertokens_python/querier.py index 6382f4cd1..69945493d 100644 --- a/supertokens_python/querier.py +++ b/supertokens_python/querier.py @@ -35,7 +35,7 @@ from typing import List, Set, Union -from .process_state import AllowedProcessStates, ProcessState +from .process_state import PROCESS_STATE, ProcessState from .utils import find_max_version, is_4xx_error, is_5xx_error from sniffio import AsyncLibraryNotFoundError from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop @@ -74,7 +74,6 @@ class Querier: def __init__(self, hosts: List[Host], rid_to_core: Union[None, str] = None): self.__hosts = hosts self.__rid_to_core = None - self.__global_cache_tag = get_timestamp_ms() if rid_to_core is not None: self.__rid_to_core = rid_to_core @@ -131,7 +130,7 @@ async def get_api_version(self, user_context: Union[Dict[str, Any], None] = None return Querier.api_version ProcessState.get_instance().add_state( - AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION + PROCESS_STATE.CALLING_SERVICE_IN_GET_API_VERSION ) async def f(url: str, method: str) -> Response: @@ -277,7 +276,7 @@ async def f(url: str, method: str) -> Response: if user_context is not None: if ( user_context.get("_default", {}).get("global_cache_tag", -1) - != self.__global_cache_tag + != Querier.__global_cache_tag ): self.invalidate_core_call_cache(user_context, False) @@ -316,7 +315,7 @@ async def f(url: str, method: str) -> Response: **user_context.get("_default", {}).get("core_call_cache", {}), unique_key: response, }, - "global_cache_tag": self.__global_cache_tag, + "global_cache_tag": Querier.__global_cache_tag, } return response @@ -443,7 +442,7 @@ def invalidate_core_call_cache( user_context.get("_default", {}).get("keep_cache_alive", False) is not True ): # there can be race conditions here, but i think we can ignore them. - self.__global_cache_tag = get_timestamp_ms() + Querier.__global_cache_tag = get_timestamp_ms() user_context["_default"] = { **user_context.get("_default", {}), @@ -498,7 +497,7 @@ async def __send_request_helper( retry_info_map[url] = max_retries ProcessState.get_instance().add_state( - AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER + PROCESS_STATE.CALLING_SERVICE_IN_REQUEST_HELPER ) response = await http_function(url, method) if ("SUPERTOKENS_ENV" in environ) and ( @@ -524,9 +523,9 @@ async def __send_request_helper( raise Exception( "SuperTokens core threw an error for a " + method - + " request to path: " + + " request to path: '" + path.get_as_string_dangerous() - + " with status code: " + + "' with status code: " + str(response.status_code) + " and message: " + response.text # type: ignore diff --git a/supertokens_python/recipe/accountlinking/__init__.py b/supertokens_python/recipe/accountlinking/__init__.py new file mode 100644 index 000000000..c8aaedec6 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Callable, Union, Optional, Dict, Any, Awaitable + +from . import types +from ...types import User +from ..session.interfaces import SessionContainer + +from .recipe import AccountLinkingRecipe + +InputOverrideConfig = types.InputOverrideConfig +RecipeLevelUser = types.RecipeLevelUser +AccountInfoWithRecipeIdAndUserId = types.AccountInfoWithRecipeIdAndUserId +ShouldAutomaticallyLink = types.ShouldAutomaticallyLink +ShouldNotAutomaticallyLink = types.ShouldNotAutomaticallyLink + + +def init( + on_account_linked: Optional[ + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None, + override: Optional[InputOverrideConfig] = None, +): + return AccountLinkingRecipe.init( + on_account_linked, should_do_automatic_account_linking, override + ) diff --git a/supertokens_python/recipe/accountlinking/asyncio/__init__.py b/supertokens_python/recipe/accountlinking/asyncio/__init__.py new file mode 100644 index 000000000..549950447 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/asyncio/__init__.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional + +from ..types import AccountInfoWithRecipeId +from supertokens_python.types import User, RecipeUserId +from ..recipe import AccountLinkingRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.asyncio import get_user + + +async def create_primary_user_id_or_link_accounts( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> User: + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Unknown recipeUserId") + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + input_user=user, + tenant_id=tenant_id, + session=session, + user_context=user_context, + ) + if link_res.status == "NO_LINK": + return user + assert link_res.user is not None + return link_res.user + + +async def get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id: str, + recipe_user_id: RecipeUserId, + user_context: Optional[Dict[str, Any]] = None, +) -> Optional[User]: + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Unknown recipeUserId") + return await AccountLinkingRecipe.get_instance().get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id=tenant_id, + user=user, + user_context=user_context, + ) + + +async def can_create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.can_create_primary_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + +async def create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.create_primary_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + +async def can_link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().recipe_implementation.can_link_accounts( + recipe_user_id=recipe_user_id, + primary_user_id=primary_user_id, + user_context=user_context, + ) + + +async def link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return ( + await AccountLinkingRecipe.get_instance().recipe_implementation.link_accounts( + recipe_user_id=recipe_user_id, + primary_user_id=primary_user_id, + user_context=user_context, + ) + ) + + +async def unlink_account( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + if user_context is None: + user_context = {} + return ( + await AccountLinkingRecipe.get_instance().recipe_implementation.unlink_account( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + ) + + +async def is_sign_up_allowed( + tenant_id: str, + new_user: AccountInfoWithRecipeId, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return await AccountLinkingRecipe.get_instance().is_sign_up_allowed( + new_user=new_user, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + user_context=user_context, + ) + + +async def is_sign_in_allowed( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Unknown recipeUserId") + + return await AccountLinkingRecipe.get_instance().is_sign_in_allowed( + user=user, + account_info=next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == recipe_user_id.get_as_string() + ), + session=session, + tenant_id=tenant_id, + sign_in_verifies_login_method=False, + user_context=user_context, + ) + + +async def is_email_change_allowed( + recipe_user_id: RecipeUserId, + new_email: str, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + raise Exception("Passed in recipe user id does not exist") + + res = await AccountLinkingRecipe.get_instance().is_email_change_allowed( + user=user, + new_email=new_email, + is_verified=is_verified, + session=session, + user_context=user_context, + ) + return res.allowed diff --git a/supertokens_python/recipe/accountlinking/interfaces.py b/supertokens_python/recipe/accountlinking/interfaces.py new file mode 100644 index 000000000..2058765d3 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/interfaces.py @@ -0,0 +1,273 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional +from typing_extensions import Literal + +if TYPE_CHECKING: + from supertokens_python.types import ( + User, + RecipeUserId, + AccountInfo, + ) + + +class RecipeInterface(ABC): + @abstractmethod + async def get_users( + self, + tenant_id: str, + time_joined_order: Literal["ASC", "DESC"], + limit: Optional[int], + pagination_token: Optional[str], + include_recipe_ids: Optional[List[str]], + query: Optional[Dict[str, str]], + user_context: Dict[str, Any], + ) -> GetUsersResult: + pass + + @abstractmethod + async def can_create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + pass + + @abstractmethod + async def create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + pass + + @abstractmethod + async def can_link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CanLinkAccountsOkResult, + CanLinkAccountsRecipeUserIdAlreadyLinkedError, + CanLinkAccountsAccountInfoAlreadyAssociatedError, + CanLinkAccountsInputUserNotPrimaryError, + ]: + pass + + @abstractmethod + async def link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsInputUserNotPrimaryError, + ]: + pass + + @abstractmethod + async def unlink_account( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> UnlinkAccountOkResult: + pass + + @abstractmethod + async def get_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> Optional[User]: + pass + + @abstractmethod + async def list_users_by_account_info( + self, + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool, + user_context: Dict[str, Any], + ) -> List[User]: + pass + + @abstractmethod + async def delete_user( + self, + user_id: str, + remove_all_linked_accounts: bool, + user_context: Dict[str, Any], + ) -> None: + pass + + +class GetUsersResult: + def __init__(self, users: List[User], next_pagination_token: Optional[str]): + self.users = users + self.next_pagination_token = next_pagination_token + + +class CanCreatePrimaryUserOkResult: + def __init__(self, was_already_a_primary_user: bool): + self.status: Literal["OK"] = "OK" + self.was_already_a_primary_user = was_already_a_primary_user + + +class CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError: + def __init__(self, primary_user_id: str, description: str): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanCreatePrimaryUserAccountInfoAlreadyAssociatedError: + def __init__(self, primary_user_id: str, description: str): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CreatePrimaryUserOkResult: + def __init__(self, user: User, was_already_a_primary_user: bool): + self.status: Literal["OK"] = "OK" + self.user = user + self.was_already_a_primary_user = was_already_a_primary_user + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "user": self.user.to_json(), + "wasAlreadyAPrimaryUser": self.was_already_a_primary_user, + } + + +class CreatePrimaryUserRecipeUserIdAlreadyLinkedError: + def __init__(self, primary_user_id: str, description: Optional[str] = None): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CreatePrimaryUserAccountInfoAlreadyAssociatedError: + def __init__(self, primary_user_id: str, description: Optional[str] = None): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanLinkAccountsOkResult: + def __init__(self, accounts_already_linked: bool): + self.status: Literal["OK"] = "OK" + self.accounts_already_linked = accounts_already_linked + + +class CanLinkAccountsRecipeUserIdAlreadyLinkedError: + def __init__( + self, primary_user_id: Optional[str] = None, description: Optional[str] = None + ): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanLinkAccountsAccountInfoAlreadyAssociatedError: + def __init__( + self, primary_user_id: Optional[str] = None, description: Optional[str] = None + ): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class CanLinkAccountsInputUserNotPrimaryError: + def __init__(self, description: Optional[str] = None): + self.status: Literal[ + "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ] = "INPUT_USER_IS_NOT_A_PRIMARY_USER" + self.description = description + + +class LinkAccountsOkResult: + def __init__(self, accounts_already_linked: bool, user: User): + self.status: Literal["OK"] = "OK" + self.accounts_already_linked = accounts_already_linked + self.user = user + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "accountsAlreadyLinked": self.accounts_already_linked, + "user": self.user.to_json(), + } + + +class LinkAccountsRecipeUserIdAlreadyLinkedError: + def __init__( + self, + primary_user_id: str, + user: User, + description: str, + ): + self.status: Literal[ + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.user = user + self.description = description + + +class LinkAccountsAccountInfoAlreadyAssociatedError: + def __init__( + self, + primary_user_id: Optional[str] = None, + description: Optional[str] = None, + ): + self.status: Literal[ + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ] = "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + self.primary_user_id = primary_user_id + self.description = description + + +class LinkAccountsInputUserNotPrimaryError: + def __init__(self): + self.status: Literal[ + "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ] = "INPUT_USER_IS_NOT_A_PRIMARY_USER" + + +class UnlinkAccountOkResult: + def __init__(self, was_recipe_user_deleted: bool, was_linked: bool): + self.status: Literal["OK"] = "OK" + self.was_recipe_user_deleted = was_recipe_user_deleted + self.was_linked = was_linked diff --git a/supertokens_python/recipe/accountlinking/recipe.py b/supertokens_python/recipe/accountlinking/recipe.py new file mode 100644 index 000000000..2f98825be --- /dev/null +++ b/supertokens_python/recipe/accountlinking/recipe.py @@ -0,0 +1,1028 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from os import environ +from typing import Any, Dict, List, Union, TYPE_CHECKING, Optional, Callable, Awaitable +from supertokens_python.supertokens import Supertokens + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe_module import APIHandled, RecipeModule +from .utils import validate_and_normalise_user_input +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from .recipe_implementation import RecipeImplementation +from supertokens_python.querier import Querier +from supertokens_python.logger import ( + log_debug_message, +) +from supertokens_python.process_state import PROCESS_STATE, ProcessState +from typing_extensions import Literal + +from .types import ( + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, + AccountInfoWithRecipeIdAndUserId, + InputOverrideConfig, + AccountInfoWithRecipeId, + AccountInfo, +) + +from .interfaces import RecipeInterface + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + from supertokens_python.types import User, LoginMethod, RecipeUserId + from supertokens_python.recipe.session import SessionContainer + from supertokens_python.framework import BaseRequest, BaseResponse + from supertokens_python.recipe.emailverification.recipe import ( + EmailVerificationRecipe, + ) + + +class EmailChangeAllowedResult: + def __init__( + self, + allowed: bool, + reason: Literal["OK", "PRIMARY_USER_CONFLICT", "ACCOUNT_TAKEOVER_RISK"], + ): + self.allowed = allowed + self.reason: Literal[ + "OK", "PRIMARY_USER_CONFLICT", "ACCOUNT_TAKEOVER_RISK" + ] = reason + + +class TryLinkingByAccountInfoOrCreatePrimaryUserResult: + def __init__(self, status: Literal["OK", "NO_LINK"], user: Optional[User]): + self.status: Literal["OK", "NO_LINK"] = status + self.user = user + + +class AccountLinkingRecipe(RecipeModule): + recipe_id = "accountlinking" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + on_account_linked: Optional[ + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None, + override: Optional[InputOverrideConfig] = None, + ): + super().__init__(recipe_id, app_info) + self.config = validate_and_normalise_user_input( + app_info, on_account_linked, should_do_automatic_account_linking, override + ) + recipe_implementation: RecipeInterface = RecipeImplementation( + Querier.get_instance(recipe_id), self, self.config + ) + + self.recipe_implementation: RecipeInterface = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) + + self.email_verification_recipe: EmailVerificationRecipe | None = None + + def register_email_verification_recipe( + self, email_verification_recipe: EmailVerificationRecipe + ): + self.email_verification_recipe = email_verification_recipe + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + def get_apis_handled(self) -> List[APIHandled]: + return [] + + async def handle_api_request( + self, + request_id: str, + tenant_id: Optional[str], + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> Union[BaseResponse, None]: + raise Exception("Should never come here") + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + on_account_linked: Optional[ + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None, + override: Optional[InputOverrideConfig] = None, + ): + def func(app_info: AppInfo): + if AccountLinkingRecipe.__instance is None: + AccountLinkingRecipe.__instance = AccountLinkingRecipe( + AccountLinkingRecipe.recipe_id, + app_info, + on_account_linked, + should_do_automatic_account_linking, + override, + ) + return AccountLinkingRecipe.__instance + raise Exception( + None, + "Accountlinking recipe has already been initialised. Please check your code for bugs.", + ) + + return func + + @staticmethod + def get_instance() -> AccountLinkingRecipe: + if AccountLinkingRecipe.__instance is None: + AccountLinkingRecipe.init()(Supertokens.get_instance().app_info) + + assert AccountLinkingRecipe.__instance is not None + return AccountLinkingRecipe.__instance + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + AccountLinkingRecipe.__instance = None + + async def get_primary_user_that_can_be_linked_to_recipe_user_id( + self, + tenant_id: str, + user: User, + user_context: Dict[str, Any], + ) -> Optional[User]: + # First we check if this user itself is a primary user or not. If it is, we return that. + if user.is_primary_user: + return user + + # Then, we try and find a primary user based on the email / phone number / third party ID. + users = await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=user.login_methods[0], + do_union_of_account_info=True, + user_context=user_context, + ) + + log_debug_message( + "getPrimaryUserThatCanBeLinkedToRecipeUserId found %d matching users" + % len(users) + ) + primary_users = [u for u in users if u.is_primary_user] + log_debug_message( + "getPrimaryUserThatCanBeLinkedToRecipeUserId found %d matching primary users" + % len(primary_users) + ) + + if len(primary_users) > 1: + # This means that the new user has account info such that it's + # spread across multiple primary user IDs. In this case, even + # if we return one of them, it won't be able to be linked anyway + # cause if we did, it would mean 2 primary users would have the + # same account info. So we return None + + # This being said, with the current set of auth recipes, it should + # never come here - cause: + # ----> If the recipeuserid is a passwordless user, then it can have either a phone + # email or both. If it has just one of them, then anyway 2 primary users can't + # exist with the same phone number / email. If it has both, then the only way + # that it can have multiple primary users returned is if there is another passwordless + # primary user with the same phone number - which is not possible, cause phone + # numbers are unique across passwordless users. + # + # ----> If the input is a third party user, then it has third party info and an email. Now there can be able to primary user with the same email, but + # there can't be another thirdparty user with the same third party info (since that is unique). + # Nor can there an email password primary user with the same email along with another + # thirdparty primary user with the same email (since emails can't be the same across primary users). + # + # ----> If the input is an email password user, then it has an email. There can't be multiple primary users with the same email anyway. + raise Exception( + "You found a bug. Please report it on github.com/supertokens/supertokens-node" + ) + + return primary_users[0] if len(primary_users) > 0 else None + + async def get_oldest_user_that_can_be_linked_to_recipe_user( + self, + tenant_id: str, + user: User, + user_context: Dict[str, Any], + ) -> Optional[User]: + # First we check if this user itself is a primary user or not. If it is, we return that since it cannot be linked to anything else + if user.is_primary_user: + return user + + # Then, we try and find matching users based on the email / phone number / third party ID. + users = await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=user.login_methods[0], + do_union_of_account_info=True, + user_context=user_context, + ) + + log_debug_message( + f"getOldestUserThatCanBeLinkedToRecipeUser found {len(users)} matching users" + ) + + # Finally select the oldest one + oldest_user = min(users, key=lambda u: u.time_joined) if users else None + return oldest_user + + async def is_sign_in_allowed( + self, + user: User, + account_info: Union[AccountInfoWithRecipeId, LoginMethod], + tenant_id: str, + session: Optional[SessionContainer], + sign_in_verifies_login_method: bool, + user_context: Dict[str, Any], + ) -> bool: + ProcessState.get_instance().add_state(PROCESS_STATE.IS_SIGN_IN_ALLOWED_CALLED) + if ( + user.is_primary_user + or user.login_methods[0].verified + or sign_in_verifies_login_method + ): + return True + + return await self.is_sign_in_up_allowed_helper( + account_info=account_info, + is_verified=user.login_methods[0].verified, + session=session, + tenant_id=tenant_id, + is_sign_in=True, + user=user, + user_context=user_context, + ) + + async def is_sign_up_allowed( + self, + new_user: AccountInfoWithRecipeId, + is_verified: bool, + session: Optional[SessionContainer], + tenant_id: str, + user_context: Dict[str, Any], + ) -> bool: + ProcessState.get_instance().add_state(PROCESS_STATE.IS_SIGN_UP_ALLOWED_CALLED) + if new_user.email is not None and new_user.phone_number is not None: + # We do this check cause below when we call list_users_by_account_info, + # we only pass in one of email or phone number + raise Exception("Please pass one of email or phone number, not both") + + return await self.is_sign_in_up_allowed_helper( + account_info=new_user, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + user_context=user_context, + user=None, + is_sign_in=False, + ) + + async def is_sign_in_up_allowed_helper( + self, + account_info: Union[AccountInfoWithRecipeId, LoginMethod], + is_verified: bool, + session: Optional[SessionContainer], + tenant_id: str, + is_sign_in: bool, + user: Optional[User], + user_context: Dict[str, Any], + ) -> bool: + ProcessState.get_instance().add_state( + PROCESS_STATE.IS_SIGN_IN_UP_ALLOWED_HELPER_CALLED + ) + + users = await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=account_info, + do_union_of_account_info=True, + user_context=user_context, + ) + + if not users: + log_debug_message( + "isSignInUpAllowedHelper returning true because no user with given account info" + ) + return True + + if is_sign_in and user is None: + raise Exception( + "This should never happen: isSignInUpAllowedHelper called with isSignIn: true, user: None" + ) + + if ( + len(users) == 1 + and is_sign_in + and user is not None + and users[0].id == user.id + ): + log_debug_message( + "isSignInUpAllowedHelper returning true because this is sign in and there is only a single user with the given account info" + ) + return True + + primary_users = [u for u in users if u.is_primary_user] + + if not primary_users: + log_debug_message("isSignInUpAllowedHelper no primary user exists") + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + account_info + ), + None, + session, + tenant_id, + user_context, + ) + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + "isSignInUpAllowedHelper returning true because account linking is disabled" + ) + return True + + if not should_do_account_linking.should_require_verification: + log_debug_message( + "isSignInUpAllowedHelper returning true because dev does not require email verification" + ) + return True + + should_allow = True + for curr_user in users: + if session is not None and curr_user.id == session.get_user_id( + user_context + ): + # We do not consider the current session user to be conflicting + # This can be useful in cases where the current sign in will mark the session user as verified + continue + + this_iteration_is_verified = False + if account_info.email is not None: + if ( + curr_user.login_methods[0].has_same_email_as(account_info.email) + and curr_user.login_methods[0].verified + ): + log_debug_message( + "isSignInUpAllowedHelper found same email for another user and verified" + ) + this_iteration_is_verified = True + + if account_info.phone_number is not None: + if ( + curr_user.login_methods[0].has_same_phone_number_as( + account_info.phone_number + ) + and curr_user.login_methods[0].verified + ): + log_debug_message( + "isSignInUpAllowedHelper found same phone number for another user and verified" + ) + this_iteration_is_verified = True + + if not this_iteration_is_verified: + # even if one of the users is not verified, we do not allow sign up (see why above). + # Sure, this allows attackers to create email password accounts with an email + # to block actual users from signing up, but that's ok, since those + # users will just see an email already exists error and then will try another + # login method. They can also still just go through the password reset flow + # and then gain access to their email password account (which can then be verified). + log_debug_message( + "isSignInUpAllowedHelper returning false cause one of the other recipe level users is not verified" + ) + should_allow = False + break + + ProcessState.get_instance().add_state( + PROCESS_STATE.IS_SIGN_IN_UP_ALLOWED_NO_PRIMARY_USER_EXISTS + ) + log_debug_message(f"isSignInUpAllowedHelper returning {should_allow}") + return should_allow + else: + if len(primary_users) > 1: + raise Exception( + "You have found a bug. Please report to https://github.com/supertokens/supertokens-node/issues" + ) + + primary_user = primary_users[0] + log_debug_message("isSignInUpAllowedHelper primary user found") + + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + account_info + ), + primary_user, + session, + tenant_id, + user_context, + ) + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + "isSignInUpAllowedHelper returning true because account linking is disabled" + ) + return True + + if not should_do_account_linking.should_require_verification: + log_debug_message( + "isSignInUpAllowedHelper returning true because dev does not require email verification" + ) + return True + + if not is_verified: + log_debug_message( + "isSignInUpAllowedHelper returning false because new user's email is not verified, and primary user with the same email was found." + ) + return False + + if session is not None and primary_user.id == session.get_user_id( + user_context + ): + return True + + for login_method in primary_user.login_methods: + if login_method.email is not None: + if ( + login_method.has_same_email_as(account_info.email) + and login_method.verified + ): + log_debug_message( + "isSignInUpAllowedHelper returning true cause found same email for primary user and verified" + ) + return True + + if login_method.phone_number is not None: + if ( + login_method.has_same_phone_number_as(account_info.phone_number) + and login_method.verified + ): + log_debug_message( + "isSignInUpAllowedHelper returning true cause found same phone number for primary user and verified" + ) + return True + + log_debug_message( + "isSignInUpAllowedHelper returning false cause primary user does not have the same email or phone number that is verified" + ) + return False + + async def is_email_change_allowed( + self, + user: User, + new_email: str, + is_verified: bool, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> EmailChangeAllowedResult: + """ + The purpose of this function is to check if a recipe user ID's email + can be changed or not. There are two conditions for when it can't be changed: + - If the recipe user is a primary user, then we need to check that the new email + doesn't belong to any other primary user. If it does, we disallow the change + since multiple primary user's can't have the same account info. + + - If the recipe user is NOT a primary user, and if is_verified is false, then + we check if there exists a primary user with the same email, and if it does + we disallow the email change cause if this email is changed, and an email + verification email is sent, then the primary user may end up clicking + on the link by mistake, causing account linking to happen which can result + in account take over if this recipe user is malicious. + """ + + for tenant_id in user.tenant_ids: + existing_users_with_new_email = ( + await self.recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=new_email), + do_union_of_account_info=False, + user_context=user_context, + ) + ) + + other_users_with_new_email = [ + u for u in existing_users_with_new_email if u.id != user.id + ] + other_primary_user_for_new_email = [ + u for u in other_users_with_new_email if u.is_primary_user + ] + + if len(other_primary_user_for_new_email) > 1: + raise Exception( + "You found a bug. Please report it on github.com/supertokens/supertokens-core" + ) + + if user.is_primary_user: + if other_primary_user_for_new_email: + log_debug_message( + f"isEmailChangeAllowed: returning false cause email change will lead to two primary users having same email on {tenant_id}" + ) + return EmailChangeAllowedResult( + allowed=False, reason="PRIMARY_USER_CONFLICT" + ) + + if is_verified: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is primary, new email is verified and doesn't belong to any other primary user" + ) + continue + + if any( + lm.has_same_email_as(new_email) and lm.verified + for lm in user.login_methods + ): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is primary, new email is verified in another login method and doesn't belong to any other primary user" + ) + continue + + if not other_users_with_new_email: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is primary and the new email doesn't belong to any other user (primary or non-primary)" + ) + continue + + should_do_account_linking = await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + other_users_with_new_email[0].login_methods[0] + ), + user, + session, + tenant_id, + user_context, + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause linking is disabled" + ) + continue + + if not should_do_account_linking.should_require_verification: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause linking doesn't require email verification" + ) + continue + + log_debug_message( + f"isEmailChangeAllowed: returning false because the user hasn't verified the new email address and there exists another user with it on {tenant_id} and linking requires verification" + ) + return EmailChangeAllowedResult( + allowed=False, reason="ACCOUNT_TAKEOVER_RISK" + ) + else: + if is_verified: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary and new email is verified" + ) + continue + + if user.login_methods[0].has_same_email_as(new_email): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary and new email is same as the older one" + ) + continue + + if other_primary_user_for_new_email: + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId( + recipe_id=user.login_methods[0].recipe_id, + email=user.login_methods[0].email, + recipe_user_id=user.login_methods[0].recipe_user_id, + phone_number=user.login_methods[0].phone_number, + third_party=user.login_methods[0].third_party, + ), + other_primary_user_for_new_email[0], + session, + tenant_id, + user_context, + ) + ) + + if isinstance( + should_do_account_linking, ShouldNotAutomaticallyLink + ): + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary there exists a primary user exists with the new email, but the dev does not have account linking enabled." + ) + continue + + if not should_do_account_linking.should_require_verification: + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary there exists a primary user exists with the new email, but the dev does not require email verification." + ) + continue + + log_debug_message( + "isEmailChangeAllowed: returning false cause input user is not a primary there exists a primary user exists with the new email." + ) + return EmailChangeAllowedResult( + allowed=False, reason="ACCOUNT_TAKEOVER_RISK" + ) + + log_debug_message( + f"isEmailChangeAllowed: can change on {tenant_id} cause input user is not a primary no primary user exists with the new email" + ) + continue + + log_debug_message( + "isEmailChangeAllowed: returning true cause email change can happen on all tenants the user is part of" + ) + return EmailChangeAllowedResult(allowed=True, reason="OK") + + # pylint:disable=no-self-use + async def verify_email_for_recipe_user_if_linked_accounts_are_verified( + self, + user: User, + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], + ) -> None: + if self.email_verification_recipe is None: + return + + if user.is_primary_user: + recipe_user_email: Optional[str] = None + is_already_verified = False + for lm in user.login_methods: + if lm.recipe_user_id.get_as_string() == recipe_user_id.get_as_string(): + recipe_user_email = lm.email + is_already_verified = lm.verified + break + + if recipe_user_email is not None: + if is_already_verified: + return + should_verify_email = False + for lm in user.login_methods: + if lm.has_same_email_as(recipe_user_email) and lm.verified: + should_verify_email = True + break + + if should_verify_email: + ev_recipe = self.email_verification_recipe.get_instance_or_throw() + resp = await ev_recipe.recipe_implementation.create_email_verification_token( + tenant_id=user.tenant_ids[0], + recipe_user_id=recipe_user_id, + email=recipe_user_email, + user_context=user_context, + ) + if resp.status == "OK": + # we purposely pass in false below cause we don't want account + # linking to happen + await ev_recipe.recipe_implementation.verify_email_using_token( + tenant_id=user.tenant_ids[0], + token=resp.token, + attempt_account_linking=False, + user_context=user_context, + ) + + async def should_become_primary_user( + self, + user: User, + tenant_id: str, + session: Optional[SessionContainer], + user_context: Dict[str, Any], + ) -> bool: + should_do_account_linking = ( + await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + user.login_methods[0] + ), + None, + session, + tenant_id, + user_context, + ) + ) + + if isinstance(should_do_account_linking, ShouldNotAutomaticallyLink): + log_debug_message( + "should_become_primary_user returning false because shouldAutomaticallyLink is false" + ) + return False + + if ( + should_do_account_linking.should_require_verification + and not user.login_methods[0].verified + ): + log_debug_message( + "should_become_primary_user returning false because shouldRequireVerification is true but the login method is not verified" + ) + return False + + log_debug_message("should_become_primary_user returning true") + return True + + async def try_linking_by_account_info_or_create_primary_user( + self, + input_user: User, + session: Optional[SessionContainer], + tenant_id: str, + user_context: Dict[str, Any], + ) -> TryLinkingByAccountInfoOrCreatePrimaryUserResult: + tries = 0 + while tries < 100: + tries += 1 + primary_user_that_can_be_linked_to_the_input_user = ( + await self.get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id=tenant_id, + user=input_user, + user_context=user_context, + ) + ) + if primary_user_that_can_be_linked_to_the_input_user is not None: + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: got primary user we can try linking" + ) + # we check if the input_user and primary_user_that_can_be_linked_to_the_input_user are linked based on recipeIds because the input_user obj could be outdated + if not any( + lm.recipe_user_id.get_as_string() + == input_user.login_methods[0].recipe_user_id.get_as_string() + for lm in primary_user_that_can_be_linked_to_the_input_user.login_methods + ): + should_do_account_linking = await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + input_user.login_methods[0] + ), + primary_user_that_can_be_linked_to_the_input_user, + session, + tenant_id, + user_context, + ) + + if isinstance( + should_do_account_linking, ShouldNotAutomaticallyLink + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldAutomaticallyLink is false" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + account_info_verified_in_prim_user = any( + ( + input_user.login_methods[0].email is not None + and lm.has_same_email_as(input_user.login_methods[0].email) + ) + or ( + input_user.login_methods[0].phone_number is not None + and lm.has_same_phone_number_as( + input_user.login_methods[0].phone_number + ) + and lm.verified + ) + for lm in primary_user_that_can_be_linked_to_the_input_user.login_methods + ) + if should_do_account_linking.should_require_verification and ( + not input_user.login_methods[0].verified + or not account_info_verified_in_prim_user + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldRequireVerification is true but the login method is not verified in the new or the primary user" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking" + ) + link_accounts_result = await self.recipe_implementation.link_accounts( + recipe_user_id=input_user.login_methods[0].recipe_user_id, + primary_user_id=primary_user_that_can_be_linked_to_the_input_user.id, + user_context=user_context, + ) + + if link_accounts_result.status == "OK": + log_debug_message( + "try_linking_by_account_info_or_create_primary_user successfully linked" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user already linked to another user" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + else: + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=input_user + ) + + oldest_user_that_can_be_linked_to_the_input_user = ( + await self.get_oldest_user_that_can_be_linked_to_recipe_user( + tenant_id=tenant_id, + user=input_user, + user_context=user_context, + ) + ) + if ( + oldest_user_that_can_be_linked_to_the_input_user is not None + and oldest_user_that_can_be_linked_to_the_input_user.id != input_user.id + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: got an older user we can try linking" + ) + should_make_older_user_primary = await self.should_become_primary_user( + oldest_user_that_can_be_linked_to_the_input_user, + tenant_id, + session, + user_context, + ) + if should_make_older_user_primary: + create_primary_user_result = await self.recipe_implementation.create_primary_user( + recipe_user_id=oldest_user_that_can_be_linked_to_the_input_user.login_methods[ + 0 + ].recipe_user_id, + user_context=user_context, + ) + if ( + create_primary_user_result.status + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + or create_primary_user_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + f"try_linking_by_account_info_or_create_primary_user: retrying because createPrimaryUser returned {create_primary_user_result.status}" + ) + continue + should_do_account_linking = await self.config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + input_user.login_methods[0] + ), + create_primary_user_result.user, + session, + tenant_id, + user_context, + ) + + if isinstance( + should_do_account_linking, ShouldNotAutomaticallyLink + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldAutomaticallyLink is false" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + if ( + should_do_account_linking.should_require_verification + and not input_user.login_methods[0].verified + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: not linking because shouldRequireVerification is true but the login method is not verified" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="NO_LINK", user=None + ) + + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking" + ) + link_accounts_result = ( + await self.recipe_implementation.link_accounts( + recipe_user_id=input_user.login_methods[0].recipe_user_id, + primary_user_id=create_primary_user_result.user.id, + user_context=user_context, + ) + ) + + if link_accounts_result.status == "OK": + log_debug_message( + "try_linking_by_account_info_or_create_primary_user successfully linked" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user already linked to another user" + ) + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=link_accounts_result.user + ) + elif ( + link_accounts_result.status + == "INPUT_USER_IS_NOT_A_PRIMARY_USER" + ): + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + else: + log_debug_message( + "try_linking_by_account_info_or_create_primary_user linking failed because of a race condition" + ) + continue + + log_debug_message( + "try_linking_by_account_info_or_create_primary_user: trying to make the current user primary" + ) + + if await self.should_become_primary_user( + input_user, tenant_id, session, user_context + ): + create_primary_user_result = ( + await self.recipe_implementation.create_primary_user( + recipe_user_id=input_user.login_methods[0].recipe_user_id, + user_context=user_context, + ) + ) + + if ( + create_primary_user_result.status + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + or create_primary_user_result.status + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + continue + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", + user=create_primary_user_result.user, + ) + else: + return TryLinkingByAccountInfoOrCreatePrimaryUserResult( + status="OK", user=input_user + ) + + raise Exception( + "This should never happen: ran out of retries for try_linking_by_account_info_or_create_primary_user" + ) diff --git a/supertokens_python/recipe/accountlinking/recipe_implementation.py b/supertokens_python/recipe/accountlinking/recipe_implementation.py new file mode 100644 index 000000000..c14cbf22c --- /dev/null +++ b/supertokens_python/recipe/accountlinking/recipe_implementation.py @@ -0,0 +1,365 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union, List, Optional +from typing_extensions import Literal + + +from .interfaces import ( + RecipeInterface, + GetUsersResult, + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserAccountInfoAlreadyAssociatedError, + CanLinkAccountsOkResult, + CanLinkAccountsRecipeUserIdAlreadyLinkedError, + CanLinkAccountsAccountInfoAlreadyAssociatedError, + CanLinkAccountsInputUserNotPrimaryError, + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsInputUserNotPrimaryError, + UnlinkAccountOkResult, +) +from supertokens_python.normalised_url_path import NormalisedURLPath +from .types import AccountLinkingConfig, RecipeLevelUser, AccountInfo +from supertokens_python.types import User, RecipeUserId + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + from .recipe import AccountLinkingRecipe + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + recipe_instance: AccountLinkingRecipe, + config: AccountLinkingConfig, + ): + super().__init__() + self.querier = querier + self.recipe_instance = recipe_instance + self.config = config + + async def get_users( + self, + tenant_id: str, + time_joined_order: Literal["ASC", "DESC"], + limit: Optional[int], + pagination_token: Optional[str], + include_recipe_ids: Optional[List[str]], + query: Optional[Dict[str, str]], + user_context: Dict[str, Any], + ) -> GetUsersResult: + include_recipe_ids_str = None + if include_recipe_ids is not None: + include_recipe_ids_str = ",".join(include_recipe_ids) + + params: Dict[str, Any] = { + "timeJoinedOrder": time_joined_order, + } + if limit is not None: + params["limit"] = limit + if pagination_token is not None: + params["paginationToken"] = pagination_token + if include_recipe_ids_str is not None: + params["includeRecipeIds"] = include_recipe_ids_str + if query: + params.update(query) + + response = await self.querier.send_get_request( + NormalisedURLPath(f"/{tenant_id or 'public'}/users"), params, user_context + ) + + return GetUsersResult( + users=[User.from_json(u) for u in response["users"]], + next_pagination_token=response.get("nextPaginationToken"), + ) + + async def can_create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CanCreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/accountlinking/user/primary/check"), + { + "recipeUserId": recipe_user_id.get_as_string(), + }, + user_context, + ) + + if response["status"] == "OK": + return CanCreatePrimaryUserOkResult(response["wasAlreadyAPrimaryUser"]) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError( + response["primaryUserId"], response["description"] + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return CanCreatePrimaryUserAccountInfoAlreadyAssociatedError( + response["primaryUserId"], response["description"] + ) + else: + raise Exception(f"Unknown response status: {response['status']}") + + async def create_primary_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[ + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserAccountInfoAlreadyAssociatedError, + ]: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/accountlinking/user/primary"), + { + "recipeUserId": recipe_user_id.get_as_string(), + }, + user_context, + ) + + if response["status"] == "OK": + return CreatePrimaryUserOkResult( + User.from_json(response["user"]), + response["wasAlreadyAPrimaryUser"], + ) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CreatePrimaryUserRecipeUserIdAlreadyLinkedError( + response["primaryUserId"], response["description"] + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return CreatePrimaryUserAccountInfoAlreadyAssociatedError( + response["primaryUserId"], response["description"] + ) + else: + raise Exception(f"Unknown response status: {response['status']}") + + async def can_link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CanLinkAccountsOkResult, + CanLinkAccountsRecipeUserIdAlreadyLinkedError, + CanLinkAccountsAccountInfoAlreadyAssociatedError, + CanLinkAccountsInputUserNotPrimaryError, + ]: + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/accountlinking/user/link/check"), + { + "recipeUserId": recipe_user_id.get_as_string(), + "primaryUserId": primary_user_id, + }, + user_context, + ) + + if response["status"] == "OK": + return CanLinkAccountsOkResult(response["accountsAlreadyLinked"]) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_PRIMARY_USER_ID_ERROR" + ): + return CanLinkAccountsRecipeUserIdAlreadyLinkedError( + response["primaryUserId"], response["description"] + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return CanLinkAccountsAccountInfoAlreadyAssociatedError( + response["primaryUserId"], response["description"] + ) + elif response["status"] == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + return CanLinkAccountsInputUserNotPrimaryError(response["description"]) + else: + raise Exception(f"Unknown response status: {response['status']}") + + async def link_accounts( + self, + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Dict[str, Any], + ) -> Union[ + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsInputUserNotPrimaryError, + ]: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/accountlinking/user/link"), + { + "recipeUserId": recipe_user_id.get_as_string(), + "primaryUserId": primary_user_id, + }, + user_context, + ) + + if response["status"] in [ + "OK", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR", + ]: + response["user"] = User.from_json(response["user"]) + + if response["status"] == "OK": + user = response["user"] + if not response["accountsAlreadyLinked"]: + await self.recipe_instance.verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + updated_user = await self.get_user( + user_id=primary_user_id, + user_context=user_context, + ) + if updated_user is None: + raise Exception("This error should never be thrown") + user = updated_user + + login_method_info = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == recipe_user_id.get_as_string() + ), + None, + ) + if login_method_info is None: + raise Exception("This error should never be thrown") + + await self.config.on_account_linked( + user, + RecipeLevelUser.from_login_method(login_method_info), + user_context, + ) + + response["user"] = user + + if response["status"] == "OK": + return LinkAccountsOkResult( + user=response["user"], + accounts_already_linked=response["accountsAlreadyLinked"], + ) + elif ( + response["status"] + == "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return LinkAccountsRecipeUserIdAlreadyLinkedError( + primary_user_id=response["primaryUserId"], + user=response["user"], + description=response["description"], + ) + elif ( + response["status"] + == "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR" + ): + return LinkAccountsAccountInfoAlreadyAssociatedError( + primary_user_id=response["primaryUserId"], + description=response["description"], + ) + elif response["status"] == "INPUT_USER_IS_NOT_A_PRIMARY_USER": + return LinkAccountsInputUserNotPrimaryError() + else: + raise Exception(f"Unknown response status: {response['status']}") + + async def unlink_account( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> UnlinkAccountOkResult: + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/accountlinking/user/unlink"), + { + "recipeUserId": recipe_user_id.get_as_string(), + }, + user_context, + ) + return UnlinkAccountOkResult( + response["wasRecipeUserDeleted"], response["wasLinked"] + ) + + async def get_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> Optional[User]: + response = await self.querier.send_get_request( + NormalisedURLPath("/user/id"), + { + "userId": user_id, + }, + user_context, + ) + if response["status"] == "OK": + return User.from_json(response["user"]) + return None + + async def list_users_by_account_info( + self, + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool, + user_context: Dict[str, Any], + ) -> List[User]: + params: Dict[str, Any] = { + "doUnionOfAccountInfo": do_union_of_account_info, + } + if account_info.email is not None: + params["email"] = account_info.email + if account_info.phone_number is not None: + params["phoneNumber"] = account_info.phone_number + + if account_info.third_party: + params["thirdPartyId"] = account_info.third_party.id + params["thirdPartyUserId"] = account_info.third_party.user_id + + response = await self.querier.send_get_request( + NormalisedURLPath(f"/{tenant_id or 'public'}/users/by-accountinfo"), + params, + user_context, + ) + + return [User.from_json(u) for u in response["users"]] + + async def delete_user( + self, + user_id: str, + remove_all_linked_accounts: bool, + user_context: Dict[str, Any], + ) -> None: + await self.querier.send_post_request( + NormalisedURLPath("/user/remove"), + { + "userId": user_id, + "removeAllLinkedAccounts": remove_all_linked_accounts, + }, + user_context, + ) diff --git a/supertokens_python/recipe/accountlinking/syncio/__init__.py b/supertokens_python/recipe/accountlinking/syncio/__init__.py new file mode 100644 index 000000000..6de893c1f --- /dev/null +++ b/supertokens_python/recipe/accountlinking/syncio/__init__.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Optional + +from supertokens_python.async_to_sync_wrapper import sync + +from ..types import AccountInfoWithRecipeId +from supertokens_python.types import RecipeUserId +from supertokens_python.recipe.session import SessionContainer + + +def create_primary_user_id_or_link_accounts( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import ( + create_primary_user_id_or_link_accounts as async_create_primary_user_id_or_link_accounts, + ) + + return sync( + async_create_primary_user_id_or_link_accounts( + tenant_id, recipe_user_id, session, user_context + ) + ) + + +def get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id: str, + recipe_user_id: RecipeUserId, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import ( + get_primary_user_that_can_be_linked_to_recipe_user_id as async_get_primary_user_that_can_be_linked_to_recipe_user_id, + ) + + return sync( + async_get_primary_user_that_can_be_linked_to_recipe_user_id( + tenant_id, recipe_user_id, user_context + ) + ) + + +def can_create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + from ..asyncio import can_create_primary_user as async_can_create_primary_user + + return sync(async_can_create_primary_user(recipe_user_id, user_context)) + + +def create_primary_user( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + from ..asyncio import create_primary_user as async_create_primary_user + + return sync(async_create_primary_user(recipe_user_id, user_context)) + + +def can_link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import can_link_accounts as async_can_link_accounts + + return sync(async_can_link_accounts(recipe_user_id, primary_user_id, user_context)) + + +def link_accounts( + recipe_user_id: RecipeUserId, + primary_user_id: str, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import link_accounts as async_link_accounts + + return sync(async_link_accounts(recipe_user_id, primary_user_id, user_context)) + + +def unlink_account( + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None +): + from ..asyncio import unlink_account as async_unlink_account + + return sync(async_unlink_account(recipe_user_id, user_context)) + + +def is_sign_up_allowed( + tenant_id: str, + new_user: AccountInfoWithRecipeId, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import is_sign_up_allowed as async_is_sign_up_allowed + + return sync( + async_is_sign_up_allowed( + tenant_id, new_user, is_verified, session, user_context + ) + ) + + +def is_sign_in_allowed( + tenant_id: str, + recipe_user_id: RecipeUserId, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import is_sign_in_allowed as async_is_sign_in_allowed + + return sync( + async_is_sign_in_allowed(tenant_id, recipe_user_id, session, user_context) + ) + + +def is_email_change_allowed( + recipe_user_id: RecipeUserId, + new_email: str, + is_verified: bool, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +): + from ..asyncio import is_email_change_allowed as async_is_email_change_allowed + + return sync( + async_is_email_change_allowed( + recipe_user_id, new_email, is_verified, session, user_context + ) + ) diff --git a/supertokens_python/recipe/accountlinking/types.py b/supertokens_python/recipe/accountlinking/types.py new file mode 100644 index 000000000..0037c8614 --- /dev/null +++ b/supertokens_python/recipe/accountlinking/types.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +from typing import Callable, Dict, Any, Union, Optional, List, TYPE_CHECKING, Awaitable +from typing_extensions import Literal +from supertokens_python.recipe.accountlinking.interfaces import ( + RecipeInterface, +) +from supertokens_python.types import AccountInfo + +if TYPE_CHECKING: + from supertokens_python.types import ( + RecipeUserId, + ThirdPartyInfo, + User, + LoginMethod, + ) + from supertokens_python.recipe.session import SessionContainer + + +class AccountInfoWithRecipeId(AccountInfo): + def __init__( + self, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + super().__init__(email, phone_number, third_party) + self.recipe_id: Literal[ + "emailpassword", "thirdparty", "passwordless" + ] = recipe_id + + def to_json(self) -> Dict[str, Any]: + return { + **super().to_json(), + "recipeId": self.recipe_id, + } + + +class RecipeLevelUser(AccountInfoWithRecipeId): + def __init__( + self, + tenant_ids: List[str], + time_joined: int, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + super().__init__(recipe_id, email, phone_number, third_party) + self.tenant_ids = tenant_ids + self.time_joined = time_joined + self.recipe_id: Literal[ + "emailpassword", "thirdparty", "passwordless" + ] = recipe_id + + @staticmethod + def from_login_method( + login_method: LoginMethod, + ) -> RecipeLevelUser: + return RecipeLevelUser( + tenant_ids=login_method.tenant_ids, + time_joined=login_method.time_joined, + recipe_id=login_method.recipe_id, + email=login_method.email, + phone_number=login_method.phone_number, + third_party=login_method.third_party, + ) + + +class AccountInfoWithRecipeIdAndUserId(AccountInfoWithRecipeId): + def __init__( + self, + recipe_user_id: Optional[RecipeUserId], + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + super().__init__(recipe_id, email, phone_number, third_party) + self.recipe_user_id = recipe_user_id + + @staticmethod + def from_account_info_or_login_method( + account_info: Union[AccountInfoWithRecipeId, LoginMethod], + ) -> AccountInfoWithRecipeIdAndUserId: + from supertokens_python.types import ( + LoginMethod as LM, + ) + + return AccountInfoWithRecipeIdAndUserId( + recipe_id=account_info.recipe_id, + email=account_info.email, + phone_number=account_info.phone_number, + third_party=account_info.third_party, + recipe_user_id=( + account_info.recipe_user_id if isinstance(account_info, LM) else None + ), + ) + + +class ShouldNotAutomaticallyLink: + def __init__(self): + pass + + +class ShouldAutomaticallyLink: + def __init__(self, should_require_verification: bool): + self.should_require_verification = should_require_verification + + +class OverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + ): + self.functions = functions + + +class InputOverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + ): + self.functions = functions + + +class AccountLinkingConfig: + def __init__( + self, + on_account_linked: Callable[ + [User, RecipeLevelUser, Dict[str, Any]], Awaitable[None] + ], + should_do_automatic_account_linking: Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ], + override: OverrideConfig, + ): + self.on_account_linked = on_account_linked + self.should_do_automatic_account_linking = should_do_automatic_account_linking + self.override = override diff --git a/supertokens_python/recipe/accountlinking/utils.py b/supertokens_python/recipe/accountlinking/utils.py new file mode 100644 index 000000000..6aa12765a --- /dev/null +++ b/supertokens_python/recipe/accountlinking/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, Awaitable + +if TYPE_CHECKING: + from .types import ( + AccountLinkingConfig, + User, + RecipeLevelUser, + AccountInfoWithRecipeIdAndUserId, + SessionContainer, + ShouldNotAutomaticallyLink, + ShouldAutomaticallyLink, + InputOverrideConfig, + ) + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + +async def default_on_account_linked(_: User, __: RecipeLevelUser, ___: Dict[str, Any]): + pass + + +_did_use_default_should_do_automatic_account_linking: bool = True + + +async def default_should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + ___: Optional[User], + ____: Optional[SessionContainer], + _____: str, + ______: Dict[str, Any], +) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: + from .types import ( + ShouldNotAutomaticallyLink as SNAL, + ) + + return SNAL() + + +def recipe_init_defined_should_do_automatic_account_linking() -> bool: + return not _did_use_default_should_do_automatic_account_linking + + +def validate_and_normalise_user_input( + _: AppInfo, + on_account_linked: Optional[ + Callable[[User, RecipeLevelUser, Dict[str, Any]], Awaitable[None]] + ] = None, + should_do_automatic_account_linking: Optional[ + Callable[ + [ + AccountInfoWithRecipeIdAndUserId, + Optional[User], + Optional[SessionContainer], + str, + Dict[str, Any], + ], + Awaitable[Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]], + ] + ] = None, + override: Union[InputOverrideConfig, None] = None, +) -> AccountLinkingConfig: + from .types import ( + OverrideConfig, + InputOverrideConfig as IOC, + AccountLinkingConfig as ALC, + ) + + global _did_use_default_should_do_automatic_account_linking + if override is None: + override = IOC() + + _did_use_default_should_do_automatic_account_linking = ( + should_do_automatic_account_linking is None + ) + + return ALC( + override=OverrideConfig(functions=override.functions), + on_account_linked=( + default_on_account_linked + if on_account_linked is None + else on_account_linked + ), + should_do_automatic_account_linking=( + default_should_do_automatic_account_linking + if should_do_automatic_account_linking is None + else should_do_automatic_account_linking + ), + ) diff --git a/supertokens_python/recipe/dashboard/api/__init__.py b/supertokens_python/recipe/dashboard/api/__init__.py index 067fa84ec..3ae3a869a 100644 --- a/supertokens_python/recipe/dashboard/api/__init__.py +++ b/supertokens_python/recipe/dashboard/api/__init__.py @@ -31,7 +31,6 @@ from .users_count_get import handle_users_count_get_api from .users_get import handle_users_get_api from .validate_key import handle_validate_key_api -from .list_tenants import handle_list_tenants_api __all__ = [ "handle_dashboard_api", @@ -54,5 +53,4 @@ "handle_emailpassword_signout_api", "handle_get_tags", "handle_analytics_post", - "handle_list_tenants_api", ] diff --git a/supertokens_python/recipe/dashboard/api/list_tenants.py b/supertokens_python/recipe/dashboard/api/list_tenants.py deleted file mode 100644 index 6d519c101..000000000 --- a/supertokens_python/recipe/dashboard/api/list_tenants.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Dict, List - -from supertokens_python.recipe.dashboard.interfaces import DashboardListTenantItem - -if TYPE_CHECKING: - from supertokens_python.recipe.dashboard.interfaces import ( - APIOptions, - APIInterface, - ) - from supertokens_python.types import APIResponse - -from supertokens_python.recipe.multitenancy.asyncio import list_all_tenants -from supertokens_python.recipe.dashboard.interfaces import ( - DashboardListTenantsGetResponse, -) - - -async def handle_list_tenants_api( - _api_implementation: APIInterface, - _tenant_id: str, - _api_options: APIOptions, - user_context: Dict[str, Any], -) -> APIResponse: - tenants = await list_all_tenants(user_context) - - final_tenants: List[DashboardListTenantItem] = [] - - for current_tenant in tenants.tenants: - dashboard_tenant = DashboardListTenantItem( - tenant_id=current_tenant.tenant_id, - emailpassword=current_tenant.emailpassword, - passwordless=current_tenant.passwordless, - third_party=current_tenant.third_party, - ) - final_tenants.append(dashboard_tenant) - - return DashboardListTenantsGetResponse(final_tenants) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py new file mode 100644 index 000000000..eb6ac3fbd --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/create_or_update_third_party_config.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.multitenancy.asyncio import ( + get_tenant, + create_or_update_third_party_config, +) +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from supertokens_python.normalised_url_domain import NormalisedURLDomain +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.recipe.thirdparty.providers.utils import do_post_request +from supertokens_python.types import APIResponse +from supertokens_python.utils import encode_base64 +from ...interfaces import APIInterface, APIOptions +import asyncio +import json + + +class CreateOrUpdateThirdPartyConfigOkResult(APIResponse): + def __init__(self, created_new: bool): + self.status = "OK" + self.created_new = created_new + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "createdNew": self.created_new} + + +class CreateOrUpdateThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class CreateOrUpdateThirdPartyConfigBoxyError(APIResponse): + def __init__(self, message: str): + self.status = "BOXY_ERROR" + self.message = message + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "message": self.message} + + +async def handle_create_or_update_third_party_config( + _: APIInterface, + tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + CreateOrUpdateThirdPartyConfigOkResult, + CreateOrUpdateThirdPartyConfigUnknownTenantError, + CreateOrUpdateThirdPartyConfigBoxyError, +]: + request_body = await api_options.request.json() + if request_body is None: + raise BadInputError("Request body is required") + provider_config = request_body.get("providerConfig") + + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return CreateOrUpdateThirdPartyConfigUnknownTenantError() + + if len(tenant_res.third_party_providers) == 0: + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers or [] + for provider in static_providers: + if ( + provider.include_in_non_public_tenants_by_default + or tenant_id == DEFAULT_TENANT_ID + ): + await create_or_update_third_party_config( + tenant_id, + ProviderConfig(third_party_id=provider.config.third_party_id), + None, + user_context, + ) + await asyncio.sleep(0.5) # 500ms delay + + if provider_config["thirdPartyId"].startswith("boxy-saml"): + boxy_url = provider_config["clients"][0]["additionalConfig"]["boxyURL"] + boxy_api_key = provider_config["clients"][0]["additionalConfig"]["boxyAPIKey"] + provider_config["clients"][0]["additionalConfig"]["boxyAPIKey"] = None + + if boxy_api_key and provider_config["clients"][0]["additionalConfig"][ + "samlInputType" + ] in [ + "xml", + "url", + ]: + request_body_input: Dict[str, Any] = { + "name": "", + "label": "", + "description": "", + "tenant": provider_config["clients"][0] + .get("additionalConfig", {}) + .get("boxyTenant") + or f"{tenant_id}-{provider_config['thirdPartyId']}", + "product": provider_config["clients"][0]["additionalConfig"].get( + "boxyProduct" + ) + or "supertokens", + "defaultRedirectUrl": provider_config["clients"][0]["additionalConfig"][ + "redirectURLs" + ][0], + "forceAuthn": False, + "encodedRawMetadata": encode_base64( + provider_config["clients"][0]["additionalConfig"].get("samlXML", "") + ), + "redirectUrl": json.dumps( + provider_config["clients"][0]["additionalConfig"]["redirectURLs"] + ), + "metadataUrl": provider_config["clients"][0]["additionalConfig"].get( + "samlURL", "" + ), + } + + normalised_domain = NormalisedURLDomain(boxy_url) + normalised_base_path = NormalisedURLPath(boxy_url) + connections_path = NormalisedURLPath("/api/v1/saml/config") + + status, resp = await do_post_request( + normalised_domain.get_as_string_dangerous() + + normalised_base_path.append( + connections_path + ).get_as_string_dangerous(), + body_params=request_body_input, + headers={"Authorization": f"Api-Key {boxy_api_key}"}, + ) + + if status != 200: + if status == 401: + return CreateOrUpdateThirdPartyConfigBoxyError("Invalid API Key") + return CreateOrUpdateThirdPartyConfigBoxyError( + resp.get("message", "Unknown error") + ) + + provider_config["clients"][0]["clientId"] = resp["clientID"] + provider_config["clients"][0]["clientSecret"] = resp["clientSecret"] + + third_party_res = await create_or_update_third_party_config( + tenant_id, ProviderConfig.from_json(provider_config), None, user_context + ) + + return CreateOrUpdateThirdPartyConfigOkResult(third_party_res.created_new) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py b/supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py new file mode 100644 index 000000000..e84c5c873 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/create_tenant.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.multitenancy.asyncio import create_or_update_tenant +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions + + +class CreateTenantOkResult(APIResponse): + def __init__(self, created_new: bool): + self.status = "OK" + self.created_new = created_new + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "createdNew": self.created_new} + + +class CreateTenantMultitenancyNotEnabledError(APIResponse): + def __init__(self): + self.status = "MULTITENANCY_NOT_ENABLED_IN_CORE_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class CreateTenantTenantIdAlreadyExistsError(APIResponse): + def __init__(self): + self.status = "TENANT_ID_ALREADY_EXISTS_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class CreateTenantInvalidTenantIdError(APIResponse): + def __init__(self, message: str): + self.status = "INVALID_TENANT_ID_ERROR" + self.message = message + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "message": self.message} + + +async def create_tenant( + _: APIInterface, + __: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + CreateTenantOkResult, + CreateTenantMultitenancyNotEnabledError, + CreateTenantTenantIdAlreadyExistsError, + CreateTenantInvalidTenantIdError, +]: + request_body = await options.request.json() + if request_body is None: + raise BadInputError("Request body is required") + tenant_id = request_body.get("tenantId") + config = {k: v for k, v in request_body.items() if k != "tenantId"} + + if not isinstance(tenant_id, str) or tenant_id == "": + raise BadInputError("Missing required parameter 'tenantId'") + + try: + tenant_res = await create_or_update_tenant( + tenant_id, TenantConfigCreateOrUpdate.from_json(config), user_context + ) + except Exception as err: + err_msg: str = str(err) + if "SuperTokens core threw an error for a " in err_msg: + if "with status code: 402" in err_msg: + return CreateTenantMultitenancyNotEnabledError() + if "with status code: 400" in err_msg: + return CreateTenantInvalidTenantIdError( + err_msg.split(" and message: ")[1] + ) + raise err + + if not tenant_res.created_new: + return CreateTenantTenantIdAlreadyExistsError() + + return CreateTenantOkResult(tenant_res.created_new) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py b/supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py new file mode 100644 index 000000000..f14bf1e1f --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/delete_tenant.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from typing_extensions import Literal +from supertokens_python.recipe.multitenancy.asyncio import delete_tenant +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions + + +class DeleteTenantOkResult(APIResponse): + def __init__(self, did_exist: bool): + self.status: Literal["OK"] = "OK" + self.did_exist = did_exist + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "didExist": self.did_exist} + + +class DeleteTenantCannotDeletePublicTenantError(APIResponse): + def __init__(self): + self.status: Literal[ + "CANNOT_DELETE_PUBLIC_TENANT_ERROR" + ] = "CANNOT_DELETE_PUBLIC_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def delete_tenant_api( + _: APIInterface, + tenant_id: str, + __: APIOptions, + user_context: Dict[str, Any], +) -> Union[DeleteTenantOkResult, DeleteTenantCannotDeletePublicTenantError]: + try: + delete_tenant_res = await delete_tenant(tenant_id, user_context) + return DeleteTenantOkResult(delete_tenant_res.did_exist) + except Exception as err: + err_msg: str = str(err) + if ( + "SuperTokens core threw an error for a " in err_msg + and "with status code: 403" in err_msg + ): + return DeleteTenantCannotDeletePublicTenantError() + raise err diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py new file mode 100644 index 000000000..84688e19c --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/delete_third_party_config.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multifactorauth.types import FactorIds +from supertokens_python.recipe.multitenancy.asyncio import ( + get_tenant, + create_or_update_third_party_config, + create_or_update_tenant, + delete_third_party_config, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +import asyncio + + +class DeleteThirdPartyConfigOkResult(APIResponse): + def __init__(self, did_config_exist: bool): + self.status: Literal["OK"] = "OK" + self.did_config_exist = did_config_exist + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "didConfigExist": self.did_config_exist} + + +class DeleteThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def delete_third_party_config_api( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[DeleteThirdPartyConfigOkResult, DeleteThirdPartyConfigUnknownTenantError]: + third_party_id = options.request.get_query_param("thirdPartyId") + + if not tenant_id or not third_party_id: + raise_bad_input_exception( + "Missing required parameter 'tenantId' or 'thirdPartyId'" + ) + + assert third_party_id is not None + + tenant_res = await get_tenant(tenant_id, user_context) + if tenant_res is None: + return DeleteThirdPartyConfigUnknownTenantError() + + third_party_ids_from_core = [ + provider.third_party_id for provider in tenant_res.third_party_providers + ] + + if len(third_party_ids_from_core) == 0: + # This means that the tenant was using the static list of providers, we need to add them all before deleting one + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = ( + mt_recipe.static_third_party_providers if mt_recipe.config else [] + ) + static_provider_ids = [ + provider.config.third_party_id for provider in static_providers + ] + + for provider_id in static_provider_ids: + await create_or_update_third_party_config( + tenant_id, + ProviderConfig(third_party_id=provider_id), + None, + user_context, + ) + # Delay after each provider to avoid rate limiting + await asyncio.sleep(0.5) # 500ms + elif ( + len(third_party_ids_from_core) == 1 + and third_party_ids_from_core[0] == third_party_id + ): + if tenant_res.first_factors is None: + # Add all static first factors except thirdparty + await create_or_update_tenant( + tenant_id, + TenantConfigCreateOrUpdate( + first_factors=[ + FactorIds.EMAILPASSWORD, + FactorIds.OTP_PHONE, + FactorIds.OTP_EMAIL, + FactorIds.LINK_PHONE, + FactorIds.LINK_EMAIL, + ] + ), + user_context, + ) + elif "thirdparty" in tenant_res.first_factors: + # Add all static first factors except thirdparty + new_first_factors = [ + factor for factor in tenant_res.first_factors if factor != "thirdparty" + ] + await create_or_update_tenant( + tenant_id, + TenantConfigCreateOrUpdate(first_factors=new_first_factors), + user_context, + ) + + result = await delete_third_party_config(tenant_id, third_party_id, user_context) + return DeleteThirdPartyConfigOkResult(result.did_config_exist) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py new file mode 100644 index 000000000..17f42b7e7 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_tenant_info.py @@ -0,0 +1,168 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Union +from typing_extensions import Literal + +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python import Supertokens +from supertokens_python.types import APIResponse +from .utils import ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init, +) +from supertokens_python.recipe.thirdparty.providers.config_utils import ( + find_and_create_provider_instance, + merge_providers_from_core_and_static, +) +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.querier import Querier +from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from ...interfaces import APIInterface, APIOptions, CoreConfigFieldInfo + + +from typing import List, Optional + + +class ThirdPartyProvider: + def __init__(self, third_party_id: str, name: str): + self.third_party_id = third_party_id + self.name = name + + def to_json(self) -> Dict[str, Any]: + return {"thirdPartyId": self.third_party_id, "name": self.name} + + +class TenantInfo: + def __init__( + self, + tenant_id: str, + third_party: List[ThirdPartyProvider], + first_factors: List[str], + required_secondary_factors: Optional[List[str]], + core_config: List[CoreConfigFieldInfo], + user_count: int, + ): + self.tenant_id = tenant_id + self.third_party = third_party + self.first_factors = first_factors + self.required_secondary_factors = required_secondary_factors + self.core_config = core_config + self.user_count = user_count + + def to_json(self) -> Dict[str, Any]: + return { + "tenantId": self.tenant_id, + "thirdParty": { + "providers": [provider.to_json() for provider in self.third_party] + }, + "firstFactors": self.first_factors, + "requiredSecondaryFactors": self.required_secondary_factors, + "coreConfig": [field.to_json() for field in self.core_config], + "userCount": self.user_count, + } + + +class GetTenantInfoOkResult(APIResponse): + def __init__(self, tenant: TenantInfo): + self.status: Literal["OK"] = "OK" + self.tenant = tenant + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "tenant": self.tenant.to_json()} + + +class GetTenantInfoUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def get_tenant_info( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[GetTenantInfoOkResult, GetTenantInfoUnknownTenantError]: + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return GetTenantInfoUnknownTenantError() + + first_factors = ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_res + ) + ) + + user_count = await Supertokens.get_instance().get_user_count( + None, tenant_id, user_context + ) + + providers_from_core = tenant_res.third_party_providers + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers + + merged_providers_from_core_and_static = merge_providers_from_core_and_static( + providers_from_core, static_providers, tenant_id == DEFAULT_TENANT_ID + ) + + querier = Querier.get_instance(options.recipe_id) + core_config = await querier.send_get_request( + NormalisedURLPath(f"/{tenant_id}/recipe/dashboard/tenant/core-config"), + {}, + user_context, + ) + + providers: List[ThirdPartyProvider] = [] + for provider in merged_providers_from_core_and_static: + try: + provider_instance = await find_and_create_provider_instance( + merged_providers_from_core_and_static, + provider.config.third_party_id, + ( + provider.config.clients[0].client_type + if provider.config.clients + else None + ), + user_context, + ) + assert provider_instance is not None + if provider_instance.config.name is None: + raise Exception("Falling back to exception block") + providers.append( + ThirdPartyProvider( + provider.config.third_party_id, + provider_instance.config.name, + ), + ) + except Exception: + providers.append( + ThirdPartyProvider( + provider.config.third_party_id, provider.config.third_party_id + ) + ) + + tenant = TenantInfo( + tenant_id, + providers, + first_factors, + tenant_res.required_secondary_factors, + [CoreConfigFieldInfo.from_json(field) for field in core_config["config"]], + user_count, + ) + + return GetTenantInfoOkResult(tenant) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py new file mode 100644 index 000000000..ef050343b --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/get_third_party_config.py @@ -0,0 +1,380 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.thirdparty import ( + ProviderClientConfig, + ProviderConfig, + ProviderInput, +) +from supertokens_python.recipe.thirdparty.provider import CommonProviderConfig, Provider +from supertokens_python.recipe.thirdparty.providers.utils import do_get_request +from supertokens_python.types import APIResponse +from supertokens_python.recipe.thirdparty.providers.config_utils import ( + find_and_create_provider_instance, + merge_providers_from_core_and_static, +) +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.normalised_url_domain import NormalisedURLDomain +from ...interfaces import APIInterface, APIOptions + + +class ProviderConfigResponse(APIResponse): + def __init__( + self, + provider_config: ProviderConfig, + is_get_authorisation_redirect_url_overridden: bool, + is_exchange_auth_code_for_oauth_tokens_overridden: bool, + is_get_user_info_overridden: bool, + ): + self.provider_config = provider_config + self.is_get_authorisation_redirect_url_overridden = ( + is_get_authorisation_redirect_url_overridden + ) + self.is_exchange_auth_code_for_oauth_tokens_overridden = ( + is_exchange_auth_code_for_oauth_tokens_overridden + ) + self.is_get_user_info_overridden = is_get_user_info_overridden + + def to_json(self) -> Dict[str, Any]: + json_response = self.provider_config.to_json() + json_response[ + "isGetAuthorisationRedirectUrlOverridden" + ] = self.is_get_authorisation_redirect_url_overridden + json_response[ + "isExchangeAuthCodeForOAuthTokensOverridden" + ] = self.is_exchange_auth_code_for_oauth_tokens_overridden + json_response["isGetUserInfoOverridden"] = self.is_get_user_info_overridden + return { + "status": "OK", + "providerConfig": json_response, + } + + +class GetThirdPartyConfigUnknownTenantError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +async def get_third_party_config( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ProviderConfigResponse, GetThirdPartyConfigUnknownTenantError]: + tenant_res = await get_tenant(tenant_id, user_context) + + if tenant_res is None: + return GetThirdPartyConfigUnknownTenantError() + + third_party_id = options.request.get_query_param("thirdPartyId") + + if third_party_id is None: + raise_bad_input_exception("Please provide thirdPartyId") + + providers_from_core = tenant_res.third_party_providers + mt_recipe = MultitenancyRecipe.get_instance() + static_providers = mt_recipe.static_third_party_providers or [] + + additional_config: Optional[Dict[str, Any]] = None + + providers_from_core = [ + provider + for provider in providers_from_core + if provider.third_party_id == third_party_id + ] + + if not providers_from_core: + providers_from_core.append(ProviderConfig(third_party_id=third_party_id)) + + if third_party_id in ["okta", "active-directory", "boxy-saml", "google-workspaces"]: + if third_party_id == "okta": + okta_domain = options.request.get_query_param("oktaDomain") + if okta_domain is not None: + additional_config = {"oktaDomain": okta_domain} + elif third_party_id == "active-directory": + directory_id = options.request.get_query_param("directoryId") + if directory_id is not None: + additional_config = {"directoryId": directory_id} + elif third_party_id == "boxy-saml": + boxy_url = options.request.get_query_param("boxyUrl") + boxy_api_key = options.request.get_query_param("boxyAPIKey") + if boxy_url is not None: + additional_config = {"boxyURL": boxy_url} + if boxy_api_key is not None: + additional_config["boxyAPIKey"] = boxy_api_key + elif third_party_id == "google-workspaces": + hd = options.request.get_query_param("hd") + if hd is not None: + additional_config = {"hd": hd} + + if additional_config is not None: + providers_from_core[0].oidc_discovery_endpoint = None + providers_from_core[0].authorization_endpoint = None + providers_from_core[0].token_endpoint = None + providers_from_core[0].user_info_endpoint = None + + if providers_from_core[0].clients is not None: + for existing_client in providers_from_core[0].clients: + if existing_client.additional_config is not None: + existing_client.additional_config = { + **existing_client.additional_config, + **additional_config, + } + else: + existing_client.additional_config = additional_config + else: + providers_from_core[0].clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=additional_config, + ) + ] + + static_providers = [ + provider + for provider in static_providers + if provider.config.third_party_id == third_party_id + ] + + if not static_providers and third_party_id == "apple": + static_providers.append( + ProviderInput( + config=ProviderConfig( + third_party_id="apple", + clients=[ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id" + ) + ], + ) + ) + ) + + additional_config = { + "teamId": "", + "keyId": "", + "privateKey": "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----", + } + + if len(static_providers) == 1 and additional_config is not None: + static_providers[0].config.oidc_discovery_endpoint = None + static_providers[0].config.authorization_endpoint = None + static_providers[0].config.token_endpoint = None + static_providers[0].config.user_info_endpoint = None + if static_providers[0].config.clients is not None: + for existing_client in static_providers[0].config.clients: + if existing_client.additional_config is not None: + existing_client.additional_config = { + **existing_client.additional_config, + **additional_config, + } + else: + existing_client.additional_config = additional_config + else: + static_providers[0].config.clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=additional_config, + ) + ] + + merged_providers_from_core_and_static = merge_providers_from_core_and_static( + providers_from_core, static_providers, True + ) + + if len(merged_providers_from_core_and_static) != 1: + raise Exception("should never come here!") + + for merged_provider in merged_providers_from_core_and_static: + if merged_provider.config.third_party_id == third_party_id: + if not merged_provider.config.clients: + merged_provider.config.clients = [ + ProviderClientConfig( + client_id="nonguessable-temporary-client-id", + additional_config=( + additional_config if additional_config is not None else None + ), + ) + ] + clients: List[ProviderClientConfig] = [] + common_provider_config: CommonProviderConfig = CommonProviderConfig( + third_party_id=third_party_id + ) + is_get_authorisation_redirect_url_overridden = False + is_exchange_auth_code_for_oauth_tokens_overridden = False + is_get_user_info_overridden = False + + for provider in merged_providers_from_core_and_static: + if provider.config.third_party_id == third_party_id: + found_correct_config = False + + for client in provider.config.clients or []: + try: + provider_instance = await find_and_create_provider_instance( + merged_providers_from_core_and_static, + third_party_id, + client.client_type, + user_context, + ) + assert provider_instance is not None + clients.append( + ProviderClientConfig( + client_id=provider_instance.config.client_id, + client_secret=provider_instance.config.client_secret, + scope=provider_instance.config.scope, + client_type=provider_instance.config.client_type, + additional_config=provider_instance.config.additional_config, + force_pkce=provider_instance.config.force_pkce, + ) + ) + # common_provider_config = CommonProviderConfig( + # third_party_id=provider_instance.config.third_party_id, + # name=provider_instance.config.name, + # authorization_endpoint=provider_instance.config.authorization_endpoint, + # authorization_endpoint_query_params=provider_instance.config.authorization_endpoint_query_params, + # token_endpoint=provider_instance.config.token_endpoint, + # token_endpoint_body_params=provider_instance.config.token_endpoint_body_params, + # user_info_endpoint=provider_instance.config.user_info_endpoint, + # user_info_endpoint_query_params=provider_instance.config.user_info_endpoint_query_params, + # user_info_endpoint_headers=provider_instance.config.user_info_endpoint_headers, + # jwks_uri=provider_instance.config.jwks_uri, + # oidc_discovery_endpoint=provider_instance.config.oidc_discovery_endpoint, + # user_info_map=provider_instance.config.user_info_map, + # require_email=provider_instance.config.require_email, + # validate_id_token_payload=provider_instance.config.validate_id_token_payload, + # validate_access_token=provider_instance.config.validate_access_token, + # generate_fake_email=provider_instance.config.generate_fake_email, + # ) + common_provider_config = provider_instance.config + + if provider.override is not None: + before_override = Provider( + config=provider_instance.config, + id=provider_instance.id, + ) + after_override = provider.override(before_override) + + if ( + before_override.get_authorisation_redirect_url + != after_override.get_authorisation_redirect_url + ): + is_get_authorisation_redirect_url_overridden = True + if ( + before_override.exchange_auth_code_for_oauth_tokens + != after_override.exchange_auth_code_for_oauth_tokens + ): + is_exchange_auth_code_for_oauth_tokens_overridden = True + if ( + before_override.get_user_info + != after_override.get_user_info + ): + is_get_user_info_overridden = True + + found_correct_config = True + except Exception: + clients.append(client) + + if not found_correct_config: + common_provider_config = provider.config + + break + + if additional_config and "privateKey" in additional_config: + additional_config["privateKey"] = "" + + temp_clients = [ + client + for client in clients + if client.client_id == "nonguessable-temporary-client-id" + ] + + final_clients = [ + client + for client in clients + if client.client_id != "nonguessable-temporary-client-id" + ] + if not final_clients: + final_clients = [ + ProviderClientConfig( + client_id="", + client_secret="", + additional_config=additional_config, + client_type=temp_clients[0].client_type, + force_pkce=temp_clients[0].force_pkce, + scope=temp_clients[0].scope, + ) + ] + + if third_party_id.startswith("boxy-saml"): + boxy_api_key = options.request.get_query_param("boxyAPIKey") + if boxy_api_key and final_clients[0].client_id: + assert isinstance(final_clients[0].additional_config, dict) + boxy_url = final_clients[0].additional_config["boxyURL"] + normalised_domain = NormalisedURLDomain(boxy_url) + normalised_base_path = NormalisedURLPath(boxy_url) + connections_path = NormalisedURLPath("/api/v1/saml/config") + + resp = await do_get_request( + normalised_domain.get_as_string_dangerous() + + normalised_base_path.append( + connections_path + ).get_as_string_dangerous(), + {"clientID": final_clients[0].client_id}, + {"Authorization": f"Api-Key {boxy_api_key}"}, + ) + + json_response = resp + final_clients[0].additional_config.update( + { + "redirectURLs": json_response["redirectUrl"], + "boxyTenant": json_response["tenant"], + "boxyProduct": json_response["product"], + } + ) + + provider_config = ProviderConfig( + third_party_id=third_party_id, + clients=final_clients, + authorization_endpoint=common_provider_config.authorization_endpoint, + authorization_endpoint_query_params=common_provider_config.authorization_endpoint_query_params, + token_endpoint=common_provider_config.token_endpoint, + token_endpoint_body_params=common_provider_config.token_endpoint_body_params, + user_info_endpoint=common_provider_config.user_info_endpoint, + user_info_endpoint_query_params=common_provider_config.user_info_endpoint_query_params, + user_info_endpoint_headers=common_provider_config.user_info_endpoint_headers, + jwks_uri=common_provider_config.jwks_uri, + oidc_discovery_endpoint=common_provider_config.oidc_discovery_endpoint, + user_info_map=common_provider_config.user_info_map, + require_email=common_provider_config.require_email, + validate_id_token_payload=common_provider_config.validate_id_token_payload, + validate_access_token=common_provider_config.validate_access_token, + generate_fake_email=common_provider_config.generate_fake_email, + name=common_provider_config.name, + ) + + return ProviderConfigResponse( + provider_config=provider_config, + is_get_authorisation_redirect_url_overridden=is_get_authorisation_redirect_url_overridden, + is_exchange_auth_code_for_oauth_tokens_overridden=is_exchange_auth_code_for_oauth_tokens_overridden, + is_get_user_info_overridden=is_get_user_info_overridden, + ) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py b/supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py new file mode 100644 index 000000000..95300b99a --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/list_all_tenants_with_login_methods.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Any +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +from .utils import ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init, +) + + +class TenantWithLoginMethods: + def __init__(self, tenant_id: str, first_factors: List[str]): + self.tenant_id = tenant_id + self.first_factors = first_factors + + +class ListAllTenantsWithLoginMethodsOkResult(APIResponse): + def __init__(self, tenants: List[TenantWithLoginMethods]): + self.status = "OK" + self.tenants = tenants + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "tenants": [ + {"tenantId": tenant.tenant_id, "firstFactors": tenant.first_factors} + for tenant in self.tenants + ], + } + + +async def list_all_tenants_with_login_methods( + _: APIInterface, + __: str, + ___: APIOptions, + user_context: Dict[str, Any], +) -> ListAllTenantsWithLoginMethodsOkResult: + tenants_res = ( + await MultitenancyRecipe.get_instance().recipe_implementation.list_all_tenants( + user_context + ) + ) + final_tenants: List[TenantWithLoginMethods] = [] + + for current_tenant in tenants_res.tenants: + login_methods = ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + current_tenant + ) + ) + + final_tenants.append( + TenantWithLoginMethods( + tenant_id=current_tenant.tenant_id, + first_factors=login_methods, + ) + ) + + return ListAllTenantsWithLoginMethodsOkResult(tenants=final_tenants) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py new file mode 100644 index 000000000..dfc279d42 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_core_config.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions + + +class UpdateTenantCoreConfigOkResult(APIResponse): + status: Literal["OK"] = "OK" + + def __init__(self): + self.status = "OK" + + def to_json(self) -> Dict[str, Literal["OK"]]: + return {"status": self.status} + + +class UpdateTenantCoreConfigUnknownTenantErrorResult(APIResponse): + status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Literal["UNKNOWN_TENANT_ERROR"]]: + return {"status": self.status} + + +class UpdateTenantCoreConfigInvalidConfigErrorResult(APIResponse): + status: Literal["INVALID_CONFIG_ERROR"] = "INVALID_CONFIG_ERROR" + + def __init__(self, message: str): + self.status = "INVALID_CONFIG_ERROR" + self.message = message + + def to_json(self) -> Dict[str, Union[Literal["INVALID_CONFIG_ERROR"], str]]: + return {"status": self.status, "message": self.message} + + +async def update_tenant_core_config( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + UpdateTenantCoreConfigOkResult, + UpdateTenantCoreConfigUnknownTenantErrorResult, + UpdateTenantCoreConfigInvalidConfigErrorResult, +]: + request_body = await options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is required") + name = request_body["name"] + value = request_body["value"] + + mt_recipe = MultitenancyRecipe.get_instance() + + tenant_res = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) + if tenant_res is None: + return UpdateTenantCoreConfigUnknownTenantErrorResult() + + try: + await mt_recipe.recipe_implementation.create_or_update_tenant( + tenant_id=tenant_id, + config=TenantConfigCreateOrUpdate( + core_config={name: value}, + ), + user_context=user_context, + ) + except Exception as err: + err_msg = str(err) + if ( + "SuperTokens core threw an error for a " in err_msg + and "with status code: 400" in err_msg + ): + return UpdateTenantCoreConfigInvalidConfigErrorResult( + message=err_msg.split(" and message: ")[1] + ) + raise err + + return UpdateTenantCoreConfigOkResult() diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py new file mode 100644 index 000000000..2c77eca4c --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_first_factor.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +from .utils import ( + get_factor_not_available_message, + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init, +) + + +class UpdateTenantFirstFactorOkResult(APIResponse): + status: Literal["OK"] = "OK" + + def __init__(self): + self.status = "OK" + + def to_json(self) -> Dict[str, Literal["OK"]]: + return {"status": self.status} + + +class UpdateTenantFirstFactorRecipeNotConfiguredOnBackendSdkErrorResult(APIResponse): + status: Literal[ + "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + ] = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + + def __init__(self, message: str): + self.status = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + self.message = message + + def to_json( + self, + ) -> Dict[str, Union[Literal["RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR"], str]]: + return {"status": self.status, "message": self.message} + + +class UpdateTenantFirstFactorUnknownTenantErrorResult(APIResponse): + status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Literal["UNKNOWN_TENANT_ERROR"]]: + return {"status": self.status} + + +async def update_tenant_first_factor( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + UpdateTenantFirstFactorOkResult, + UpdateTenantFirstFactorRecipeNotConfiguredOnBackendSdkErrorResult, + UpdateTenantFirstFactorUnknownTenantErrorResult, +]: + request_body = await options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is required") + factor_id = request_body["factorId"] + enable = request_body["enable"] + + mt_recipe = MultitenancyRecipe.get_instance() + + if enable is True: + if factor_id not in mt_recipe.all_available_first_factors: + return UpdateTenantFirstFactorRecipeNotConfiguredOnBackendSdkErrorResult( + message=get_factor_not_available_message( + factor_id, mt_recipe.all_available_first_factors + ) + ) + + tenant_res = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) + + if tenant_res is None: + return UpdateTenantFirstFactorUnknownTenantErrorResult() + + first_factors = ( + get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_res + ) + ) + + if enable is True: + if factor_id not in first_factors: + first_factors.append(factor_id) + else: + first_factors = [f for f in first_factors if f != factor_id] + + await mt_recipe.recipe_implementation.create_or_update_tenant( + tenant_id=tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + ), + user_context=user_context, + ) + + return UpdateTenantFirstFactorOkResult() diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py new file mode 100644 index 000000000..ad4947754 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/update_tenant_secondary_factor.py @@ -0,0 +1,147 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.types import APIResponse +from ...interfaces import APIInterface, APIOptions +from .utils import ( + get_factor_not_available_message, + get_normalised_required_secondary_factors_based_on_tenant_config_from_core_and_sdk_init, +) + + +class UpdateTenantSecondaryFactorOkResult(APIResponse): + status: Literal["OK"] = "OK" + is_mfa_requirements_for_auth_overridden: bool + + def __init__(self, is_mfa_requirements_for_auth_overridden: bool): + self.status = "OK" + self.is_mfa_requirements_for_auth_overridden = ( + is_mfa_requirements_for_auth_overridden + ) + + def to_json(self) -> Dict[str, Union[Literal["OK"], bool]]: + return { + "status": self.status, + "isMFARequirementsForAuthOverridden": self.is_mfa_requirements_for_auth_overridden, + } + + +class UpdateTenantSecondaryFactorRecipeNotConfiguredOnBackendSdkErrorResult( + APIResponse +): + status: Literal[ + "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + ] = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + + def __init__(self, message: str): + self.status = "RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR" + self.message = message + + def to_json( + self, + ) -> Dict[str, Union[Literal["RECIPE_NOT_CONFIGURED_ON_BACKEND_SDK_ERROR"], str]]: + return {"status": self.status, "message": self.message} + + +class UpdateTenantSecondaryFactorMfaNotInitializedErrorResult(APIResponse): + status: Literal["MFA_NOT_INITIALIZED_ERROR"] = "MFA_NOT_INITIALIZED_ERROR" + + def __init__(self): + self.status = "MFA_NOT_INITIALIZED_ERROR" + + def to_json(self) -> Dict[str, Literal["MFA_NOT_INITIALIZED_ERROR"]]: + return {"status": self.status} + + +class UpdateTenantSecondaryFactorUnknownTenantErrorResult(APIResponse): + status: Literal["UNKNOWN_TENANT_ERROR"] = "UNKNOWN_TENANT_ERROR" + + def __init__(self): + self.status = "UNKNOWN_TENANT_ERROR" + + def to_json(self) -> Dict[str, Literal["UNKNOWN_TENANT_ERROR"]]: + return {"status": self.status} + + +async def update_tenant_secondary_factor( + _: APIInterface, + tenant_id: str, + options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + UpdateTenantSecondaryFactorOkResult, + UpdateTenantSecondaryFactorRecipeNotConfiguredOnBackendSdkErrorResult, + UpdateTenantSecondaryFactorMfaNotInitializedErrorResult, + UpdateTenantSecondaryFactorUnknownTenantErrorResult, +]: + request_body = await options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is required") + factor_id = request_body["factorId"] + enable = request_body["enable"] + + mt_recipe = MultitenancyRecipe.get_instance() + mfa_instance = MultiFactorAuthRecipe.get_instance() + + if mfa_instance is None: + return UpdateTenantSecondaryFactorMfaNotInitializedErrorResult() + + tenant_res = await mt_recipe.recipe_implementation.get_tenant( + tenant_id=tenant_id, user_context=user_context + ) + + if tenant_res is None: + return UpdateTenantSecondaryFactorUnknownTenantErrorResult() + + if enable is True: + all_available_secondary_factors = ( + await mfa_instance.get_all_available_secondary_factor_ids(tenant_res) + ) + + if factor_id not in all_available_secondary_factors: + return ( + UpdateTenantSecondaryFactorRecipeNotConfiguredOnBackendSdkErrorResult( + message=get_factor_not_available_message( + factor_id, all_available_secondary_factors + ) + ) + ) + + secondary_factors = await get_normalised_required_secondary_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_res + ) + + if enable is True: + if factor_id not in secondary_factors: + secondary_factors.append(factor_id) + else: + secondary_factors = [f for f in secondary_factors if f != factor_id] + + await mt_recipe.recipe_implementation.create_or_update_tenant( + tenant_id=tenant_id, + config=TenantConfigCreateOrUpdate( + required_secondary_factors=secondary_factors if secondary_factors else None, + ), + user_context=user_context, + ) + + return UpdateTenantSecondaryFactorOkResult( + is_mfa_requirements_for_auth_overridden=mfa_instance.is_get_mfa_requirements_for_auth_overridden + ) diff --git a/supertokens_python/recipe/dashboard/api/multitenancy/utils.py b/supertokens_python/recipe/dashboard/api/multitenancy/utils.py new file mode 100644 index 000000000..9782cea50 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/multitenancy/utils.py @@ -0,0 +1,102 @@ +from typing import List +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import FactorIds +from supertokens_python.recipe.multifactorauth.utils import ( + is_factor_configured_for_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + + +def get_normalised_first_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_details_from_core: TenantConfig, +) -> List[str]: + first_factors: List[str] + + mt_instance = MultitenancyRecipe.get_instance() + + if tenant_details_from_core.first_factors is not None: + first_factors = ( + tenant_details_from_core.first_factors + ) # highest priority, config from core + elif mt_instance.static_first_factors is not None: + first_factors = mt_instance.static_first_factors # next priority, static config + else: + # Fallback to all available factors (de-duplicated) + first_factors = list(set(mt_instance.all_available_first_factors)) + + # we now filter out all available first factors by checking if they are valid because + # we want to return the ones that can work. this would be based on what recipes are enabled + # on the core and also first_factors configured in the core and supertokens.init + # Also, this way, in the front end, the developer can just check for first_factors for + # enabled recipes in all cases irrespective of whether they are using MFA or not + valid_first_factors: List[str] = [] + for factor_id in first_factors: + if is_factor_configured_for_tenant( + all_available_first_factors=mt_instance.all_available_first_factors, + first_factors=first_factors, + factor_id=factor_id, + ): + valid_first_factors.append(factor_id) + + return valid_first_factors + + +def get_factor_not_available_message( + factor_id: str, available_factors: List[str] +) -> str: + recipe_name = factor_id_to_recipe(factor_id) + if recipe_name != "Passwordless": + return f"Please initialise {recipe_name} recipe to be able to use this login method" + + passwordless_factors = [ + FactorIds.LINK_EMAIL, + FactorIds.LINK_PHONE, + FactorIds.OTP_EMAIL, + FactorIds.OTP_PHONE, + ] + passwordless_factors_not_available = [ + f for f in passwordless_factors if f not in available_factors + ] + + if len(passwordless_factors_not_available) == 4: + return ( + "Please initialise Passwordless recipe to be able to use this login method" + ) + + flow_type, contact_method = factor_id.split("-") + return f"Please ensure that Passwordless recipe is initialised with contactMethod: {contact_method.upper()} and flowType: {'USER_INPUT_CODE' if flow_type == 'otp' else 'MAGIC_LINK'}" + + +def factor_id_to_recipe(factor_id: str) -> str: + factor_id_to_recipe_map = { + "emailpassword": "Emailpassword", + "thirdparty": "ThirdParty", + "otp-email": "Passwordless", + "otp-phone": "Passwordless", + "link-email": "Passwordless", + "link-phone": "Passwordless", + "totp": "Totp", + } + + return factor_id_to_recipe_map.get(factor_id, "") + + +async def get_normalised_required_secondary_factors_based_on_tenant_config_from_core_and_sdk_init( + tenant_details_from_core: TenantConfig, +) -> List[str]: + mfa_instance = MultiFactorAuthRecipe.get_instance() + + if mfa_instance is None: + return [] + + secondary_factors = await mfa_instance.get_all_available_secondary_factor_ids( + tenant_details_from_core + ) + secondary_factors = [ + factor_id + for factor_id in secondary_factors + if factor_id in (tenant_details_from_core.required_secondary_factors or []) + ] + + return secondary_factors diff --git a/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py b/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py new file mode 100644 index 000000000..3fe8a55af --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/user/create/emailpassword_user.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, Union + +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.emailpassword.asyncio import sign_up +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + SignUpOkResult, +) +from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe +from supertokens_python.types import APIResponse, User, RecipeUserId + + +class CreateEmailPasswordUserOkResponse(APIResponse): + def __init__(self, user: User, recipe_user_id: RecipeUserId): + self.status = "OK" + self.user = user + self.recipe_user_id = recipe_user_id + + def to_json(self): + return { + "status": self.status, + "user": self.user.to_json(), + "recipeUserId": self.recipe_user_id.get_as_string(), + } + + +class CreateEmailPasswordUserFeatureNotEnabledResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class CreateEmailPasswordUserEmailAlreadyExistsResponse(APIResponse): + def __init__(self): + self.status = "EMAIL_ALREADY_EXISTS_ERROR" + + def to_json(self): + return {"status": self.status} + + +class CreateEmailPasswordUserEmailValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status = "EMAIL_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +class CreateEmailPasswordUserPasswordValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status = "PASSWORD_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +async def create_email_password_user( + _: APIInterface, + tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + CreateEmailPasswordUserOkResponse, + CreateEmailPasswordUserFeatureNotEnabledResponse, + CreateEmailPasswordUserEmailAlreadyExistsResponse, + CreateEmailPasswordUserEmailValidationErrorResponse, + CreateEmailPasswordUserPasswordValidationErrorResponse, +]: + email_password: EmailPasswordRecipe + try: + email_password = EmailPasswordRecipe.get_instance() + except Exception: + return CreateEmailPasswordUserFeatureNotEnabledResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise BadInputError("Request body is missing") + + email = request_body.get("email") + password = request_body.get("password") + + if not isinstance(email, str): + raise BadInputError( + "Required parameter 'email' is missing or has an invalid type" + ) + + if not isinstance(password, str): + raise BadInputError( + "Required parameter 'password' is missing or has an invalid type" + ) + + email_form_field = next( + field + for field in email_password.config.sign_up_feature.form_fields + if field.id == "email" + ) + validate_email_error = await email_form_field.validate(email, tenant_id) + + if validate_email_error is not None: + return CreateEmailPasswordUserEmailValidationErrorResponse(validate_email_error) + + password_form_field = next( + field + for field in email_password.config.sign_up_feature.form_fields + if field.id == "password" + ) + validate_password_error = await password_form_field.validate(password, tenant_id) + + if validate_password_error is not None: + return CreateEmailPasswordUserPasswordValidationErrorResponse( + validate_password_error + ) + + response = await sign_up(tenant_id, email, password) + + if isinstance(response, SignUpOkResult): + return CreateEmailPasswordUserOkResponse(response.user, response.recipe_user_id) + elif isinstance(response, EmailAlreadyExistsError): + return CreateEmailPasswordUserEmailAlreadyExistsResponse() + else: + raise Exception( + "This should never happen: EmailPassword.sign_up threw a session user related error without passing a session" + ) diff --git a/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py b/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py new file mode 100644 index 000000000..00e470194 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/user/create/passwordless_user.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Union +from typing_extensions import Literal + +from supertokens_python.exceptions import BadInputError +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.passwordless import ( + ContactEmailOnlyConfig, + ContactEmailOrPhoneConfig, + ContactPhoneOnlyConfig, +) +from supertokens_python.recipe.passwordless.asyncio import signinup +from supertokens_python.recipe.passwordless.recipe import PasswordlessRecipe +from supertokens_python.types import APIResponse, User, RecipeUserId +from phonenumbers import parse as parse_phone_number, format_number, PhoneNumberFormat + + +class CreatePasswordlessUserOkResponse(APIResponse): + def __init__( + self, + created_new_recipe_user: bool, + user: User, + recipe_user_id: RecipeUserId, + ): + self.status: Literal["OK"] = "OK" + self.created_new_recipe_user = created_new_recipe_user + self.user = user + self.recipe_user_id = recipe_user_id + + def to_json(self): + return { + "status": self.status, + "createdNewRecipeUser": self.created_new_recipe_user, + "user": self.user.to_json(), + "recipeUserId": self.recipe_user_id.get_as_string(), + } + + +class CreatePasswordlessUserFeatureNotEnabledResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class CreatePasswordlessUserEmailValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status: Literal["EMAIL_VALIDATION_ERROR"] = "EMAIL_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +class CreatePasswordlessUserPhoneValidationErrorResponse(APIResponse): + def __init__(self, message: str): + self.status: Literal["PHONE_VALIDATION_ERROR"] = "PHONE_VALIDATION_ERROR" + self.message = message + + def to_json(self): + return {"status": self.status, "message": self.message} + + +async def create_passwordless_user( + _: APIInterface, + tenant_id: str, + api_options: APIOptions, + __: Dict[str, Any], +) -> Union[ + CreatePasswordlessUserOkResponse, + CreatePasswordlessUserFeatureNotEnabledResponse, + CreatePasswordlessUserEmailValidationErrorResponse, + CreatePasswordlessUserPhoneValidationErrorResponse, +]: + passwordless_recipe: PasswordlessRecipe + try: + passwordless_recipe = PasswordlessRecipe.get_instance() + except Exception: + return CreatePasswordlessUserFeatureNotEnabledResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise BadInputError("Request body is missing") + + email = request_body.get("email") + phone_number = request_body.get("phoneNumber") + + if (email is not None and phone_number is not None) or ( + email is None and phone_number is None + ): + raise BadInputError("Please provide exactly one of email or phoneNumber") + + if email is not None and ( + isinstance(passwordless_recipe.config.contact_config, ContactEmailOnlyConfig) + or isinstance( + passwordless_recipe.config.contact_config, ContactEmailOrPhoneConfig + ) + ): + email = email.strip() + validation_error = ( + await passwordless_recipe.config.contact_config.validate_email_address( + email, tenant_id + ) + ) + if validation_error is not None: + return CreatePasswordlessUserEmailValidationErrorResponse(validation_error) + + if phone_number is not None and ( + isinstance(passwordless_recipe.config.contact_config, ContactPhoneOnlyConfig) + or isinstance( + passwordless_recipe.config.contact_config, ContactEmailOrPhoneConfig + ) + ): + validation_error = ( + await passwordless_recipe.config.contact_config.validate_phone_number( + phone_number, tenant_id + ) + ) + if validation_error is not None: + return CreatePasswordlessUserPhoneValidationErrorResponse(validation_error) + + try: + parsed_phone_number = parse_phone_number(phone_number) + phone_number = format_number(parsed_phone_number, PhoneNumberFormat.E164) + except Exception: + # This can happen if the user has provided their own impl of validate_phone_number + # and the phone number is valid according to their impl, but not according to the phonenumbers lib. + phone_number = phone_number.strip() + + response = await signinup(tenant_id, email=email, phone_number=phone_number) + + return CreatePasswordlessUserOkResponse( + created_new_recipe_user=response.created_new_recipe_user, + user=response.user, + recipe_user_id=response.recipe_user_id, + ) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py b/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py index 44e59e427..82868584c 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_delete.py @@ -2,7 +2,7 @@ from ...interfaces import APIInterface, APIOptions, UserDeleteAPIResponse from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python import Supertokens +from supertokens_python.asyncio import delete_user async def handle_user_delete( @@ -12,10 +12,24 @@ async def handle_user_delete( _user_context: Dict[str, Any], ) -> UserDeleteAPIResponse: user_id = api_options.request.get_query_param("userId") + remove_all_linked_accounts_query_value = api_options.request.get_query_param( + "removeAllLinkedAccounts" + ) - if user_id is None: + if remove_all_linked_accounts_query_value is not None: + remove_all_linked_accounts_query_value = ( + remove_all_linked_accounts_query_value.strip().lower() + ) + + remove_all_linked_accounts = ( + True + if remove_all_linked_accounts_query_value is None + else remove_all_linked_accounts_query_value == "true" + ) + + if user_id is None or user_id == "": raise_bad_input_exception("Missing required parameter 'userId'") - await Supertokens.get_instance().delete_user(user_id, _user_context) + await delete_user(user_id, remove_all_linked_accounts) return UserDeleteAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py index 54e0c2a9e..2462b9ffa 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_get.py @@ -7,6 +7,7 @@ UserEmailVerifyGetAPIResponse, FeatureNotEnabledError, ) +from supertokens_python.types import RecipeUserId from typing import Union, Dict, Any @@ -18,15 +19,17 @@ async def handle_user_email_verify_get( user_context: Dict[str, Any], ) -> Union[UserEmailVerifyGetAPIResponse, FeatureNotEnabledError]: req = api_options.request - user_id = req.get_query_param("userId") + recipe_user_id = req.get_query_param("recipeUserId") - if user_id is None: - raise_bad_input_exception("Missing required parameter 'userId'") + if recipe_user_id is None: + raise_bad_input_exception("Missing required parameter 'recipeUserId'") try: - EmailVerificationRecipe.get_instance() + EmailVerificationRecipe.get_instance_or_throw() except Exception: return FeatureNotEnabledError() - is_verified = await is_email_verified(user_id, user_context=user_context) + is_verified = await is_email_verified( + RecipeUserId(recipe_user_id), user_context=user_context + ) return UserEmailVerifyGetAPIResponse(is_verified) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py index fd4443ae1..56206839d 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_put.py @@ -11,6 +11,8 @@ VerifyEmailUsingTokenInvalidTokenError, ) +from supertokens_python.types import RecipeUserId + from ...interfaces import ( APIInterface, APIOptions, @@ -25,12 +27,12 @@ async def handle_user_email_verify_put( user_context: Dict[str, Any], ) -> UserEmailVerifyPutAPIResponse: request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id = request_body.get("userId") + recipe_user_id = request_body.get("recipeUserId") verified = request_body.get("verified") - if user_id is None or not isinstance(user_id, str): + if recipe_user_id is None or not isinstance(recipe_user_id, str): raise_bad_input_exception( - "Required parameter 'userId' is missing or has an invalid type" + "Required parameter 'recipeUserId' is missing or has an invalid type" ) if verified is None or not isinstance(verified, bool): @@ -40,7 +42,10 @@ async def handle_user_email_verify_put( if verified: token_response = await create_email_verification_token( - tenant_id=tenant_id, user_id=user_id, email=None, user_context=user_context + tenant_id=tenant_id, + recipe_user_id=RecipeUserId(recipe_user_id), + email=None, + user_context=user_context, ) if isinstance( @@ -57,6 +62,6 @@ async def handle_user_email_verify_put( raise Exception("Should not come here") else: - await unverify_email(user_id, user_context=user_context) + await unverify_email(RecipeUserId(recipe_user_id), user_context=user_context) return UserEmailVerifyPutAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py index b00b0a84e..8707053d6 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_email_verify_token_post.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Union +from supertokens_python.asyncio import get_user from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailverification.asyncio import ( @@ -13,6 +14,8 @@ UserEmailVerifyTokenPostAPIEmailAlreadyVerifiedErrorResponse, ) +from supertokens_python.types import RecipeUserId + async def handle_email_verify_token_post( _api_interface: APIInterface, @@ -24,15 +27,24 @@ async def handle_email_verify_token_post( UserEmailVerifyTokenPostAPIEmailAlreadyVerifiedErrorResponse, ]: request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id = request_body.get("userId") + recipe_user_id = request_body.get("recipeUserId") - if user_id is None or not isinstance(user_id, str): + if recipe_user_id is None or not isinstance(recipe_user_id, str): raise_bad_input_exception( - "Required parameter 'userId' is missing or has an invalid type" + "Required parameter 'recipeUserId' is missing or has an invalid type" ) + user = await get_user(recipe_user_id, user_context) + + if user is None: + raise_bad_input_exception("User not found") + res = await send_email_verification_email( - tenant_id=tenant_id, user_id=user_id, email=None, user_context=user_context + tenant_id=tenant_id, + user_id=user.id, + recipe_user_id=RecipeUserId(recipe_user_id), + email=None, + user_context=user_context, ) if isinstance(res, SendEmailVerificationEmailAlreadyVerifiedError): diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_get.py index 0d8350b0b..9d0dcef3e 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_get.py @@ -1,7 +1,10 @@ from typing import Union, Dict, Any +from supertokens_python.asyncio import get_user from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.recipe.dashboard.utils import get_user_for_recipe_id +from supertokens_python.recipe.dashboard.utils import ( + UserWithMetadata, +) from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.usermetadata.asyncio import get_user_metadata @@ -10,9 +13,7 @@ APIOptions, UserGetAPINoUserFoundError, UserGetAPIOkResponse, - UserGetAPIRecipeNotInitialisedError, ) -from ...utils import is_recipe_initialised, is_valid_recipe_id async def handle_user_get( @@ -20,45 +21,31 @@ async def handle_user_get( _tenant_id: str, api_options: APIOptions, _user_context: Dict[str, Any], -) -> Union[ - UserGetAPINoUserFoundError, - UserGetAPIOkResponse, - UserGetAPIRecipeNotInitialisedError, -]: +) -> Union[UserGetAPINoUserFoundError, UserGetAPIOkResponse,]: user_id = api_options.request.get_query_param("userId") - recipe_id = api_options.request.get_query_param("recipeId") if user_id is None: raise_bad_input_exception("Missing required parameter 'userId'") - if recipe_id is None: - raise_bad_input_exception("Missing required parameter 'recipeId'") - - if not is_valid_recipe_id(recipe_id): - raise_bad_input_exception("Invalid recipe id") - - if not is_recipe_initialised(recipe_id): - return UserGetAPIRecipeNotInitialisedError() - - user_response = await get_user_for_recipe_id(user_id, recipe_id) + user_response = await get_user(user_id, _user_context) if user_response is None: return UserGetAPINoUserFoundError() - user = user_response.user + user_with_metadata: UserWithMetadata = UserWithMetadata().from_user(user_response) try: UserMetadataRecipe.get_instance() except Exception: - user.first_name = "FEATURE_NOT_ENABLED" - user.last_name = "FEATURE_NOT_ENABLED" + user_with_metadata.first_name = "FEATURE_NOT_ENABLED" + user_with_metadata.last_name = "FEATURE_NOT_ENABLED" - return UserGetAPIOkResponse(recipe_id, user) + return UserGetAPIOkResponse(user_with_metadata) user_metadata = await get_user_metadata(user_id, user_context=_user_context) first_name = user_metadata.metadata.get("first_name", "") last_name = user_metadata.metadata.get("last_name", "") - user.first_name = first_name - user.last_name = last_name + user_with_metadata.first_name = first_name + user_with_metadata.last_name = last_name - return UserGetAPIOkResponse(recipe_id, user) + return UserGetAPIOkResponse(user_with_metadata) diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py index fd9dcee3e..1ed63365c 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_password_put.py @@ -1,23 +1,12 @@ -from typing import Any, Callable, Dict, List, Union +from typing import Any, Dict, Union from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.emailpassword import EmailPasswordRecipe -from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_token as ep_create_reset_password_token, -) -from supertokens_python.recipe.emailpassword.asyncio import ( - reset_password_using_token as ep_reset_password_using_token, -) -from supertokens_python.recipe.emailpassword.constants import FORM_FIELD_PASSWORD_ID from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordOkResult, - CreateResetPasswordWrongUserIdError, - ResetPasswordUsingTokenInvalidTokenError, - ResetPasswordUsingTokenOkResult, + PasswordPolicyViolationError, + UnknownUserIdError, ) -from supertokens_python.recipe.emailpassword.types import NormalisedFormField - -from supertokens_python.utils import Awaitable +from supertokens_python.types import RecipeUserId from ...interfaces import ( APIInterface, @@ -34,71 +23,33 @@ async def handle_user_password_put( user_context: Dict[str, Any], ) -> Union[UserPasswordPutAPIResponse, UserPasswordPutAPIInvalidPasswordErrorResponse]: request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id = request_body.get("userId") + recipe_user_id = request_body.get("recipeUserId") new_password = request_body.get("newPassword") - if user_id is None or not isinstance(user_id, str): - raise_bad_input_exception("Missing required parameter 'userId'") + if recipe_user_id is None or not isinstance(recipe_user_id, str): + raise_bad_input_exception("Missing required parameter 'recipeUserId'") if new_password is None or not isinstance(new_password, str): raise_bad_input_exception("Missing required parameter 'newPassword'") - async def reset_password( - form_fields: List[NormalisedFormField], - create_reset_password_token: Callable[ - [str, str, Dict[str, Any]], - Awaitable[ - Union[CreateResetPasswordOkResult, CreateResetPasswordWrongUserIdError] - ], - ], - reset_password_using_token: Callable[ - [str, str, str, Dict[str, Any]], - Awaitable[ - Union[ - ResetPasswordUsingTokenOkResult, - ResetPasswordUsingTokenInvalidTokenError, - ] - ], - ], - ) -> Union[ - UserPasswordPutAPIResponse, UserPasswordPutAPIInvalidPasswordErrorResponse - ]: - password_form_field = [ - field for field in form_fields if field.id == FORM_FIELD_PASSWORD_ID - ][0] - - password_validation_error = await password_form_field.validate( - new_password, tenant_id + email_password_recipe = EmailPasswordRecipe.get_instance() + update_response = ( + await email_password_recipe.recipe_implementation.update_email_or_password( + recipe_user_id=RecipeUserId(recipe_user_id), + email=None, + password=new_password, + apply_password_policy=True, + tenant_id_for_password_policy=tenant_id, + user_context=user_context, ) + ) - if password_validation_error is not None: - return UserPasswordPutAPIInvalidPasswordErrorResponse( - password_validation_error - ) - - password_reset_token = await create_reset_password_token( - tenant_id, user_id, user_context - ) - - if isinstance(password_reset_token, CreateResetPasswordWrongUserIdError): - # Techincally it can but its an edge case so we assume that it wont - # UNKNOWN_USER_ID_ERROR - raise Exception("Should never come here") - - password_reset_response = await reset_password_using_token( - tenant_id, password_reset_token.token, new_password, user_context + if isinstance(update_response, PasswordPolicyViolationError): + return UserPasswordPutAPIInvalidPasswordErrorResponse( + error=update_response.failure_reason ) - if isinstance( - password_reset_response, ResetPasswordUsingTokenInvalidTokenError - ): - # RESET_PASSWORD_INVALID_TOKEN_ERROR - raise Exception("Should not come here") + if isinstance(update_response, UnknownUserIdError): + raise Exception("Should never come here") - return UserPasswordPutAPIResponse() - - return await reset_password( - EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields, - ep_create_reset_password_token, - ep_reset_password_using_token, - ) + return UserPasswordPutAPIResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py index 25b4b6360..ff8b47215 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_put.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_put.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union +from typing_extensions import Literal from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.dashboard.utils import ( get_user_for_recipe_id, - is_valid_recipe_id, ) from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.asyncio import ( @@ -11,159 +11,241 @@ ) from supertokens_python.recipe.emailpassword.constants import FORM_FIELD_EMAIL_ID from supertokens_python.recipe.emailpassword.interfaces import ( - UpdateEmailOrPasswordEmailAlreadyExistsError, + EmailAlreadyExistsError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, +) +from supertokens_python.recipe.passwordless import ( + ContactEmailOnlyConfig, + ContactEmailOrPhoneConfig, + ContactPhoneOnlyConfig, + PasswordlessRecipe, ) -from supertokens_python.recipe.passwordless import PasswordlessRecipe from supertokens_python.recipe.passwordless.asyncio import ( update_user as pless_update_user, ) from supertokens_python.recipe.passwordless.interfaces import ( - UpdateUserEmailAlreadyExistsError as EmailAlreadyExistsErrorResponse, -) -from supertokens_python.recipe.passwordless.interfaces import ( - UpdateUserPhoneNumberAlreadyExistsError as PhoneNumberAlreadyExistsError, -) -from supertokens_python.recipe.passwordless.interfaces import ( - UpdateUserUnknownUserIdError as PlessUpdateUserUnknownUserIdError, + UpdateUserEmailAlreadyExistsError, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ) from supertokens_python.recipe.passwordless.utils import ( - ContactEmailOnlyConfig, - ContactEmailOrPhoneConfig, - ContactPhoneOnlyConfig, default_validate_email, default_validate_phone_number, ) from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.usermetadata.asyncio import update_user_metadata +from supertokens_python.types import RecipeUserId from ...interfaces import ( APIInterface, APIOptions, - UserPutAPIEmailAlreadyExistsErrorResponse, - UserPutAPIInvalidEmailErrorResponse, - UserPutAPIInvalidPhoneErrorResponse, - UserPutAPIOkResponse, - UserPutPhoneAlreadyExistsAPIResponse, + APIResponse, ) +class OkResponse(APIResponse): + status: Literal["OK"] + + def __init__(self): + self.status = "OK" + + def to_json(self): + return {"status": self.status} + + +class EmailAlreadyExistsErrorResponse(APIResponse): + status: Literal["EMAIL_ALREADY_EXISTS_ERROR"] + + def __init__(self): + self.status = "EMAIL_ALREADY_EXISTS_ERROR" + + def to_json(self): + return {"status": self.status} + + +class InvalidEmailErrorResponse(APIResponse): + status: Literal["INVALID_EMAIL_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "INVALID_EMAIL_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + +class PhoneAlreadyExistsErrorResponse(APIResponse): + status: Literal["PHONE_ALREADY_EXISTS_ERROR"] + + def __init__(self): + self.status = "PHONE_ALREADY_EXISTS_ERROR" + + def to_json(self): + return {"status": self.status} + + +class InvalidPhoneErrorResponse(APIResponse): + status: Literal["INVALID_PHONE_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "INVALID_PHONE_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + +class EmailChangeNotAllowedErrorResponse(APIResponse): + status: Literal["EMAIL_CHANGE_NOT_ALLOWED_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "EMAIL_CHANGE_NOT_ALLOWED_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + +class PhoneNumberChangeNotAllowedErrorResponse(APIResponse): + status: Literal["PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR"] + error: str + + def __init__(self, error: str): + self.status = "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR" + self.error = error + + def to_json(self): + return {"status": self.status, "error": self.error} + + async def update_email_for_recipe_id( recipe_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: str, tenant_id: str, user_context: Dict[str, Any], ) -> Union[ - UserPutAPIOkResponse, - UserPutAPIInvalidEmailErrorResponse, - UserPutAPIEmailAlreadyExistsErrorResponse, + OkResponse, + InvalidEmailErrorResponse, + EmailAlreadyExistsErrorResponse, + EmailChangeNotAllowedErrorResponse, ]: - validation_error: Optional[str] = None - if recipe_id == "emailpassword": - form_fields = ( - EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields - ) email_form_fields = [ - form_field - for form_field in form_fields - if form_field.id == FORM_FIELD_EMAIL_ID + field + for field in EmailPasswordRecipe.get_instance().config.sign_up_feature.form_fields + if field.id == FORM_FIELD_EMAIL_ID ] validation_error = await email_form_fields[0].validate(email, tenant_id) if validation_error is not None: - return UserPutAPIInvalidEmailErrorResponse(validation_error) + return InvalidEmailErrorResponse(validation_error) email_update_response = await ep_update_email_or_password( - user_id, email, user_context=user_context + recipe_user_id, email=email, user_context=user_context ) - if isinstance( - email_update_response, UpdateEmailOrPasswordEmailAlreadyExistsError + if isinstance(email_update_response, EmailAlreadyExistsError): + return EmailAlreadyExistsErrorResponse() + elif isinstance( + email_update_response, UpdateEmailOrPasswordEmailChangeNotAllowedError ): - return UserPutAPIEmailAlreadyExistsErrorResponse() + return EmailChangeNotAllowedErrorResponse(email_update_response.reason) - return UserPutAPIOkResponse() + return OkResponse() if recipe_id == "passwordless": - validation_error = None - - passwordless_config = PasswordlessRecipe.get_instance().config.contact_config + passwordless_config = PasswordlessRecipe.get_instance().config - if isinstance(passwordless_config.contact_method, ContactPhoneOnlyConfig): + if isinstance(passwordless_config.contact_config, ContactPhoneOnlyConfig): validation_error = await default_validate_email(email, tenant_id) - - elif isinstance( - passwordless_config, (ContactEmailOnlyConfig, ContactEmailOrPhoneConfig) - ): - validation_error = await passwordless_config.validate_email_address( - email, tenant_id - ) + else: + if isinstance( + passwordless_config.contact_config, + (ContactEmailOnlyConfig, ContactEmailOrPhoneConfig), + ): + validation_error = ( + await passwordless_config.contact_config.validate_email_address( + email, tenant_id + ) + ) + else: + raise Exception("Should never come here") if validation_error is not None: - return UserPutAPIInvalidEmailErrorResponse(validation_error) + return InvalidEmailErrorResponse(validation_error) update_result = await pless_update_user( - user_id, email, user_context=user_context + recipe_user_id, email=email, user_context=user_context ) - if isinstance(update_result, PlessUpdateUserUnknownUserIdError): + if isinstance(update_result, UpdateUserUnknownUserIdError): raise Exception("Should never come here") + elif isinstance(update_result, UpdateUserEmailAlreadyExistsError): + return EmailAlreadyExistsErrorResponse() + elif isinstance( + update_result, + ( + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, + ), + ): + return EmailChangeNotAllowedErrorResponse(update_result.reason) - if isinstance(update_result, EmailAlreadyExistsErrorResponse): - return UserPutAPIEmailAlreadyExistsErrorResponse() - - return UserPutAPIOkResponse() + return OkResponse() # If it comes here then the user is a third party user in which case the UI should not have allowed this raise Exception("Should never come here") async def update_phone_for_recipe_id( - recipe_id: str, - user_id: str, + recipe_user_id: RecipeUserId, phone: str, tenant_id: str, user_context: Dict[str, Any], ) -> Union[ - UserPutAPIOkResponse, - UserPutAPIInvalidPhoneErrorResponse, - UserPutPhoneAlreadyExistsAPIResponse, + OkResponse, + InvalidPhoneErrorResponse, + PhoneAlreadyExistsErrorResponse, + PhoneNumberChangeNotAllowedErrorResponse, ]: - validation_error: Optional[str] = None - - if recipe_id == "passwordless": - validation_error = None - - passwordless_config = PasswordlessRecipe.get_instance().config.contact_config - - if isinstance(passwordless_config, ContactEmailOnlyConfig): - validation_error = await default_validate_phone_number(phone, tenant_id) - elif isinstance( - passwordless_config, (ContactPhoneOnlyConfig, ContactEmailOrPhoneConfig) - ): - validation_error = await passwordless_config.validate_phone_number( + passwordless_config = PasswordlessRecipe.get_instance().config + + if isinstance(passwordless_config.contact_config, ContactEmailOnlyConfig): + validation_error = await default_validate_phone_number(phone, tenant_id) + elif isinstance( + passwordless_config.contact_config, + (ContactPhoneOnlyConfig, ContactEmailOrPhoneConfig), + ): + validation_error = ( + await passwordless_config.contact_config.validate_phone_number( phone, tenant_id ) - - if validation_error is not None: - return UserPutAPIInvalidPhoneErrorResponse(validation_error) - - update_result = await pless_update_user( - user_id, phone_number=phone, user_context=user_context ) + else: + raise Exception("Invalid contact config") - if isinstance(update_result, PlessUpdateUserUnknownUserIdError): - raise Exception("Should never come here") + if validation_error is not None: + return InvalidPhoneErrorResponse(validation_error) - if isinstance(update_result, PhoneNumberAlreadyExistsError): - return UserPutPhoneAlreadyExistsAPIResponse() + update_result = await pless_update_user( + recipe_user_id, phone_number=phone, user_context=user_context + ) - return UserPutAPIOkResponse() + if isinstance(update_result, UpdateUserUnknownUserIdError): + raise Exception("Should never come here") + elif isinstance(update_result, UpdateUserPhoneNumberAlreadyExistsError): + return PhoneAlreadyExistsErrorResponse() + elif isinstance(update_result, PhoneNumberChangeNotAllowedError): + return PhoneNumberChangeNotAllowedErrorResponse(update_result.reason) - # If it comes here then the user is a third party user in which case the UI should not have allowed this - raise Exception("Should never come here") + return OkResponse() async def handle_user_put( @@ -172,66 +254,63 @@ async def handle_user_put( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - UserPutAPIOkResponse, - UserPutAPIInvalidEmailErrorResponse, - UserPutAPIEmailAlreadyExistsErrorResponse, - UserPutAPIInvalidPhoneErrorResponse, - UserPutPhoneAlreadyExistsAPIResponse, + OkResponse, + InvalidEmailErrorResponse, + EmailAlreadyExistsErrorResponse, + InvalidPhoneErrorResponse, + PhoneAlreadyExistsErrorResponse, + EmailChangeNotAllowedErrorResponse, + PhoneNumberChangeNotAllowedErrorResponse, ]: - request_body: Dict[str, Any] = await api_options.request.json() # type: ignore - user_id: Optional[str] = request_body.get("userId") - recipe_id: Optional[str] = request_body.get("recipeId") - first_name: Optional[str] = request_body.get("firstName") - last_name: Optional[str] = request_body.get("lastName") - email: Optional[str] = request_body.get("email") - phone: Optional[str] = request_body.get("phone") - - if not isinstance(user_id, str): - return raise_bad_input_exception( - "Required parameter 'userId' is missing or has an invalid type" + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + recipe_user_id = request_body.get("recipeUserId") + recipe_id = request_body.get("recipeId") + first_name = request_body.get("firstName") + last_name = request_body.get("lastName") + email = request_body.get("email") + phone = request_body.get("phone") + + if not isinstance(recipe_user_id, str): + raise_bad_input_exception( + "Required parameter 'recipeUserId' is missing or has an invalid type" ) if not isinstance(recipe_id, str): - return raise_bad_input_exception( + raise_bad_input_exception( "Required parameter 'recipeId' is missing or has an invalid type" ) - if not is_valid_recipe_id(recipe_id): - raise_bad_input_exception("Invalid recipe id") - - if first_name is None and not isinstance(first_name, str): + if not isinstance(first_name, str): raise_bad_input_exception( "Required parameter 'firstName' is missing or has an invalid type" ) - if last_name is None and not isinstance(last_name, str): + if not isinstance(last_name, str): raise_bad_input_exception( "Required parameter 'lastName' is missing or has an invalid type" ) - if email is None and not isinstance(email, str): + if not isinstance(email, str): raise_bad_input_exception( "Required parameter 'email' is missing or has an invalid type" ) - if phone is None and not isinstance(phone, str): + if not isinstance(phone, str): raise_bad_input_exception( "Required parameter 'phone' is missing or has an invalid type" ) - user_response = await get_user_for_recipe_id(user_id, recipe_id) + user_response = await get_user_for_recipe_id( + RecipeUserId(recipe_user_id), recipe_id, user_context + ) - if user_response is None: + if user_response.user is None or user_response.recipe is None: raise Exception("Should never come here") - first_name = first_name.strip() - last_name = last_name.strip() - email = email.strip() - phone = phone.strip() - - if first_name != "" or last_name != "": + if first_name.strip() or last_name.strip(): is_recipe_initialized = False - try: UserMetadataRecipe.get_instance() is_recipe_initialized = True @@ -241,28 +320,37 @@ async def handle_user_put( if is_recipe_initialized: metadata_update: Dict[str, Any] = {} - if first_name != "": - metadata_update["first_name"] = first_name + if first_name.strip(): + metadata_update["first_name"] = first_name.strip() - if last_name != "": - metadata_update["last_name"] = last_name + if last_name.strip(): + metadata_update["last_name"] = last_name.strip() - await update_user_metadata(user_id, metadata_update, user_context) + await update_user_metadata( + user_response.user.user.id, metadata_update, user_context + ) - if email != "": + if email.strip(): email_update_response = await update_email_for_recipe_id( - user_response.recipe, user_id, email, tenant_id, user_context + user_response.recipe, + RecipeUserId(recipe_user_id), + email.strip(), + tenant_id, + user_context, ) - if not isinstance(email_update_response, UserPutAPIOkResponse): + if not isinstance(email_update_response, OkResponse): return email_update_response - if phone != "": + if phone.strip(): phone_update_response = await update_phone_for_recipe_id( - user_response.recipe, user_id, phone, tenant_id, user_context + RecipeUserId(recipe_user_id), + phone.strip(), + tenant_id, + user_context, ) - if not isinstance(phone_update_response, UserPutAPIOkResponse): + if not isinstance(phone_update_response, OkResponse): return phone_update_response - return UserPutAPIOkResponse() + return OkResponse() diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py index 40b4e82b9..bea9de4bd 100644 --- a/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_sessions_get.py @@ -28,9 +28,7 @@ async def handle_sessions_get( # Passing tenant id as None sets fetch_across_all_tenants to True # which is what we want here. - session_handles = await get_all_session_handles_for_user( - user_id, None, user_context - ) + session_handles = await get_all_session_handles_for_user(user_id) sessions: List[Optional[SessionInfo]] = [None for _ in session_handles] async def call_(i: int, session_handle: str): diff --git a/supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py b/supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py new file mode 100644 index 000000000..2fef96569 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userdetails/user_unlink_get.py @@ -0,0 +1,31 @@ +from typing import Any, Dict +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.accountlinking.asyncio import unlink_account +from supertokens_python.recipe.dashboard.utils import RecipeUserId +from ...interfaces import APIInterface, APIOptions +from supertokens_python.types import APIResponse + + +class UserUnlinkGetOkResult(APIResponse): + def __init__(self): + self.status: Literal["OK"] = "OK" + + def to_json(self): + return {"status": self.status} + + +async def handle_user_unlink_get( + _api_interface: APIInterface, + _tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> UserUnlinkGetOkResult: + recipe_user_id = api_options.request.get_query_param("recipeUserId") + + if recipe_user_id is None: + raise_bad_input_exception("Required field recipeUserId is missing") + + await unlink_account(RecipeUserId(recipe_user_id), user_context) + + return UserUnlinkGetOkResult() diff --git a/supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py b/supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py new file mode 100644 index 000000000..39f9b9d44 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/add_role_to_user.py @@ -0,0 +1,71 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import add_role_to_user +from supertokens_python.recipe.userroles.interfaces import AddRoleToUserOkResult +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, did_user_already_have_role: bool): + self.status: Literal["OK"] = "OK" + self.did_user_already_have_role = did_user_already_have_role + + def to_json(self): + return { + "status": self.status, + "didUserAlreadyHaveRole": self.did_user_already_have_role, + } + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_ROLE_ERROR"] = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def add_role_to_user_api( + _: APIInterface, tenant_id: str, api_options: APIOptions, __: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse, UnknownRoleErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + + user_id = request_body.get("userId") + role = request_body.get("role") + + if role is None or not isinstance(role, str): + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if user_id is None or not isinstance(user_id, str): + raise_bad_input_exception( + "Required parameter 'userId' is missing or has an invalid type" + ) + + response = await add_role_to_user(tenant_id, user_id, role) + + if isinstance(response, AddRoleToUserOkResult): + return OkResponse( + did_user_already_have_role=response.did_user_already_have_role + ) + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py b/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py new file mode 100644 index 000000000..3006c1b97 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/get_role_to_user.py @@ -0,0 +1,46 @@ +from typing import Any, Union, List +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import get_roles_for_user +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, roles: List[str]): + self.status: Literal["OK"] = "OK" + self.roles = roles + + def to_json(self): + return { + "status": self.status, + "roles": self.roles, + } + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def get_roles_for_user_api( + _: APIInterface, tenant_id: str, api_options: APIOptions, __: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + user_id = api_options.request.get_query_param("userId") + + if user_id is None: + raise_bad_input_exception( + "Required parameter 'userId' is missing or has an invalid type" + ) + + response = await get_roles_for_user(tenant_id, user_id) + return OkResponse(roles=response.roles) diff --git a/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py b/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py new file mode 100644 index 000000000..bab61e045 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/permissions/get_permissions_for_role.py @@ -0,0 +1,64 @@ +from typing import Any, Dict, Union, List +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import get_permissions_for_role +from supertokens_python.recipe.userroles.interfaces import ( + GetPermissionsForRoleOkResult, +) + +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkPermissionsForRoleResponse(APIResponse): + def __init__(self, permissions: List[str]): + self.status = "OK" + self.permissions = permissions + + def to_json(self): + return {"status": self.status, "permissions": self.permissions} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def get_permissions_for_role_api( + _api_interface: APIInterface, + _tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[ + OkPermissionsForRoleResponse, + FeatureNotEnabledErrorResponse, + UnknownRoleErrorResponse, +]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + role = api_options.request.get_query_param("role") + + if role is None: + raise_bad_input_exception("Required parameter 'role' is missing") + + response = await get_permissions_for_role(role) + + if isinstance(response, GetPermissionsForRoleOkResult): + return OkPermissionsForRoleResponse(response.permissions) + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py b/supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py new file mode 100644 index 000000000..c6a479311 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/permissions/remove_permissions_from_role.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, List, Union +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import remove_permissions_from_role +from supertokens_python.recipe.userroles.interfaces import ( + RemovePermissionsFromRoleOkResult, +) +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self): + self.status = "OK" + + def to_json(self): + return {"status": self.status} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def remove_permissions_from_role_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Dict[str, Any] +) -> Union[OkResponse, FeatureNotEnabledErrorResponse, UnknownRoleErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + + role = request_body.get("role") + permissions: Union[List[str], None] = request_body.get("permissions") + + if role is None or not isinstance(role, str): + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if permissions is None: + raise_bad_input_exception( + "Required parameter 'permissions' is missing or has an invalid type" + ) + + response = await remove_permissions_from_role(role, permissions) + if isinstance(response, RemovePermissionsFromRoleOkResult): + return OkResponse() + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py b/supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py new file mode 100644 index 000000000..b110d79de --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/remove_user_role.py @@ -0,0 +1,65 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import remove_user_role +from supertokens_python.recipe.userroles.interfaces import RemoveUserRoleOkResult +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, did_user_have_role: bool): + self.status: Literal["OK"] = "OK" + self.did_user_have_role = did_user_have_role + + def to_json(self): + return { + "status": self.status, + "didUserHaveRole": self.did_user_have_role, + } + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +class UnknownRoleErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_ROLE_ERROR"] = "UNKNOWN_ROLE_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def remove_user_role_api( + _: APIInterface, tenant_id: str, api_options: APIOptions, __: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse, UnknownRoleErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + user_id = api_options.request.get_query_param("userId") + role = api_options.request.get_query_param("role") + + if role is None: + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if user_id is None: + raise_bad_input_exception( + "Required parameter 'userId' is missing or has an invalid type" + ) + + response = await remove_user_role(tenant_id, user_id, role) + + if isinstance(response, RemoveUserRoleOkResult): + return OkResponse(did_user_have_role=response.did_user_have_role) + else: + return UnknownRoleErrorResponse() diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py b/supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py new file mode 100644 index 000000000..e5a89b4fb --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/create_role_or_add_permissions.py @@ -0,0 +1,54 @@ +from typing import Any, List, Union +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import ( + create_new_role_or_add_permissions, +) +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, created_new_role: bool): + self.status = "OK" + self.created_new_role = created_new_role + + def to_json(self): + return {"status": self.status, "createdNewRole": self.created_new_role} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def create_role_or_add_permissions_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + request_body = await api_options.request.json() + if request_body is None: + raise_bad_input_exception("Request body is missing") + + role = request_body.get("role") + permissions: Union[List[str], None] = request_body.get("permissions") + + if role is None or not isinstance(role, str): + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + if permissions is None: + raise_bad_input_exception( + "Required parameter 'permissions' is missing or has an invalid type" + ) + + response = await create_new_role_or_add_permissions(role, permissions) + return OkResponse(created_new_role=response.created_new_role) diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py b/supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py new file mode 100644 index 000000000..77816d72c --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/delete_role.py @@ -0,0 +1,43 @@ +from typing import Any, Union +from typing_extensions import Literal +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import delete_role +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, did_role_exist: bool): + self.status: Literal["OK"] = "OK" + self.did_role_exist = did_role_exist + + def to_json(self): + return {"status": self.status, "didRoleExist": self.did_role_exist} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def delete_role_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + role = api_options.request.get_query_param("role") + + if role is None: + raise_bad_input_exception( + "Required parameter 'role' is missing or has an invalid type" + ) + + response = await delete_role(role) + return OkResponse(did_role_exist=response.did_role_exist) diff --git a/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py b/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py new file mode 100644 index 000000000..71dc11d21 --- /dev/null +++ b/supertokens_python/recipe/dashboard/api/userroles/roles/get_all_roles.py @@ -0,0 +1,35 @@ +from typing import Any, Union, List +from typing_extensions import Literal +from supertokens_python.recipe.dashboard.interfaces import APIInterface, APIOptions +from supertokens_python.recipe.userroles.asyncio import get_all_roles +from supertokens_python.recipe.userroles.recipe import UserRolesRecipe +from supertokens_python.types import APIResponse + + +class OkResponse(APIResponse): + def __init__(self, roles: List[str]): + self.status: Literal["OK"] = "OK" + self.roles = roles + + def to_json(self): + return {"status": self.status, "roles": self.roles} + + +class FeatureNotEnabledErrorResponse(APIResponse): + def __init__(self): + self.status: Literal["FEATURE_NOT_ENABLED_ERROR"] = "FEATURE_NOT_ENABLED_ERROR" + + def to_json(self): + return {"status": self.status} + + +async def get_all_roles_api( + _: APIInterface, __: str, api_options: APIOptions, ___: Any +) -> Union[OkResponse, FeatureNotEnabledErrorResponse]: + try: + UserRolesRecipe.get_instance() + except Exception: + return FeatureNotEnabledErrorResponse() + + response = await get_all_roles() + return OkResponse(roles=response.roles) diff --git a/supertokens_python/recipe/dashboard/api/users_get.py b/supertokens_python/recipe/dashboard/api/users_get.py index 949e569da..5e9c0ae7b 100644 --- a/supertokens_python/recipe/dashboard/api/users_get.py +++ b/supertokens_python/recipe/dashboard/api/users_get.py @@ -15,9 +15,6 @@ import asyncio from typing import TYPE_CHECKING, Any, Awaitable, List, Dict -from typing_extensions import Literal - -from supertokens_python.supertokens import Supertokens from ...usermetadata import UserMetadataRecipe from ...usermetadata.asyncio import get_user_metadata @@ -29,9 +26,9 @@ APIOptions, APIInterface, ) - from supertokens_python.types import APIResponse from supertokens_python.exceptions import GeneralError, raise_bad_input_exception +from supertokens_python.asyncio import get_users_newest_first, get_users_oldest_first async def handle_users_get_api( @@ -39,38 +36,40 @@ async def handle_users_get_api( tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], -) -> APIResponse: +) -> DashboardUsersGetResponse: _ = api_implementation limit = api_options.request.get_query_param("limit") if limit is None: raise_bad_input_exception("Missing required parameter 'limit'") - time_joined_order: Literal["ASC", "DESC"] = api_options.request.get_query_param( # type: ignore - "timeJoinedOrder", "DESC" - ) + time_joined_order = api_options.request.get_query_param("timeJoinedOrder", "DESC") if time_joined_order not in ["ASC", "DESC"]: - raise_bad_input_exception("Invalid value recieved for 'timeJoinedOrder'") + raise_bad_input_exception("Invalid value received for 'timeJoinedOrder'") pagination_token = api_options.request.get_query_param("paginationToken") + query = get_search_params_from_url(api_options.request.get_original_url()) - users_response = await Supertokens.get_instance().get_users( + users_response = await ( + get_users_newest_first + if time_joined_order == "DESC" + else get_users_oldest_first + )( tenant_id, - time_joined_order=time_joined_order, limit=int(limit), pagination_token=pagination_token, - include_recipe_ids=None, - query=api_options.request.get_query_params(), + query=query, user_context=user_context, ) - # user metadata bulk fetch with batches: - try: UserMetadataRecipe.get_instance() except GeneralError: + users_with_metadata: List[UserWithMetadata] = [ + UserWithMetadata().from_user(user) for user in users_response.users + ] return DashboardUsersGetResponse( - users_response.users, users_response.next_pagination_token + users_with_metadata, users_response.next_pagination_token ) users_with_metadata: List[UserWithMetadata] = [ @@ -80,15 +79,13 @@ async def handle_users_get_api( async def get_user_metadata_and_update_user(user_idx: int) -> None: user = users_response.users[user_idx] - user_metadata = await get_user_metadata(user.user_id, user_context) + user_metadata = await get_user_metadata(user.id) first_name = user_metadata.metadata.get("first_name") last_name = user_metadata.metadata.get("last_name") - # None becomes null which is acceptable for the dashboard. users_with_metadata[user_idx].first_name = first_name users_with_metadata[user_idx].last_name = last_name - # Batch calls to get user metadata: for i, _ in enumerate(users_response.users): metadata_fetch_awaitables.append(get_user_metadata_and_update_user(i)) @@ -108,10 +105,22 @@ async def get_user_metadata_and_update_user(user_idx: int) -> None: ) ] await asyncio.gather(*promises_to_call) - promise_arr_start_position += batch_size return DashboardUsersGetResponse( users_with_metadata, users_response.next_pagination_token, ) + + +def get_search_params_from_url(path: str) -> Dict[str, str]: + from urllib.parse import urlparse, parse_qs + + url_object = urlparse("https://example.com" + path) + params = parse_qs(url_object.query) + search_query = { + key: value[0] + for key, value in params.items() + if key not in ["limit", "timeJoinedOrder", "paginationToken"] + } + return search_query diff --git a/supertokens_python/recipe/dashboard/constants.py b/supertokens_python/recipe/dashboard/constants.py index 258718254..7f7b9287d 100644 --- a/supertokens_python/recipe/dashboard/constants.py +++ b/supertokens_python/recipe/dashboard/constants.py @@ -12,4 +12,16 @@ SIGN_OUT_API = "/api/signout" SEARCH_TAGS_API = "/api/search/tags" DASHBOARD_ANALYTICS_API = "/api/analytics" -TENANTS_LIST_API = "/api/tenants/list" +TENANT_THIRD_PARTY_CONFIG_API = "/api/thirdparty/config" +TENANT_API = "/api/tenant" +LIST_TENANTS_WITH_LOGIN_METHODS = "/api/tenants" +UPDATE_TENANT_CORE_CONFIG_API = "/api/tenant/core-config" +UPDATE_TENANT_FIRST_FACTOR_API = "/api/tenant/first-factor" +UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API = "/api/tenant/required-secondary-factor" +CREATE_EMAIL_PASSWORD_USER = "/api/user/emailpassword" +CREATE_PASSWORDLESS_USER = "/api/user/passwordless" +UNLINK_USER = "/api/user/unlink" +USERROLES_PERMISSIONS_API = "/api/userroles/role/permissions" +USERROLES_REMOVE_PERMISSIONS_API = "/api/userroles/role/permissions/remove" +USERROLES_ROLE_API = "/api/userroles/role" +USERROLES_USER_API = "/api/userroles/user/roles" diff --git a/supertokens_python/recipe/dashboard/interfaces.py b/supertokens_python/recipe/dashboard/interfaces.py index 20033301b..9a3435488 100644 --- a/supertokens_python/recipe/dashboard/interfaces.py +++ b/supertokens_python/recipe/dashboard/interfaces.py @@ -15,8 +15,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union - -from supertokens_python.types import User +from typing_extensions import Literal +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig from ...types import APIResponse @@ -27,12 +27,6 @@ from supertokens_python.recipe.session.interfaces import SessionInformationResult from supertokens_python.framework import BaseRequest, BaseResponse - from supertokens_python.recipe.multitenancy.interfaces import ( - EmailPasswordConfig, - PasswordlessConfig, - ThirdPartyConfig, - ) - class SessionInfo: def __init__(self, info: SessionInformationResult) -> None: @@ -94,7 +88,7 @@ class DashboardUsersGetResponse(APIResponse): def __init__( self, - users: Union[List[User], List[UserWithMetadata]], + users: List[UserWithMetadata], next_pagination_token: Optional[str], ): self.users = users @@ -109,28 +103,14 @@ def to_json(self) -> Dict[str, Any]: class DashboardListTenantItem: - def __init__( - self, - tenant_id: str, - emailpassword: EmailPasswordConfig, - passwordless: PasswordlessConfig, - third_party: ThirdPartyConfig, - ): - self.tenant_id = tenant_id - self.emailpassword = emailpassword - self.passwordless = passwordless - self.third_party = third_party + def __init__(self, tenant_config: TenantConfig): + self.tenant_config = tenant_config - def to_json(self): - res = { - "tenantId": self.tenant_id, - "emailPassword": self.emailpassword.to_json(), - "passwordless": self.passwordless.to_json(), - "thirdParty": self.third_party.to_json(), + def to_json(self) -> Dict[str, Any]: + return { + "tenantId": self.tenant_config.tenant_id, } - return res - class DashboardListTenantsGetResponse(APIResponse): status: str = "OK" @@ -158,14 +138,12 @@ def to_json(self) -> Dict[str, Any]: class UserGetAPIOkResponse(APIResponse): status: str = "OK" - def __init__(self, recipe_id: str, user: UserWithMetadata): - self.recipe_id = recipe_id + def __init__(self, user: UserWithMetadata): self.user = user def to_json(self) -> Dict[str, Any]: return { "status": self.status, - "recipeId": self.recipe_id, "user": self.user.to_json(), } @@ -177,13 +155,6 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class UserGetAPIRecipeNotInitialisedError(APIResponse): - status: str = "RECIPE_NOT_INITIALISED" - - def to_json(self) -> Dict[str, Any]: - return {"status": self.status} - - class FeatureNotEnabledError(APIResponse): status: str = "FEATURE_NOT_ENABLED_ERROR" @@ -210,7 +181,6 @@ def __init__(self, sessions: List[SessionInfo]): "accessTokenPayload": s.access_token_payload, "expiry": s.expiry, "sessionDataInDatabase": s.session_data_in_database, - "tenantId": s.tenant_id, "timeCreated": s.time_created, "userId": s.user_id, "sessionHandle": s.session_handle, @@ -355,3 +325,60 @@ class AnalyticsResponse(APIResponse): def to_json(self) -> Dict[str, Any]: return {"status": self.status} + + +class CoreConfigFieldInfo: + def __init__( + self, + key: str, + value_type: Literal["string", "boolean", "number"], + value: Union[str, int, float, bool, None], + description: str, + is_different_across_tenants: bool, + possible_values: Union[List[str], None] = None, + is_nullable: bool = False, + default_value: Union[str, int, float, bool, None] = None, + is_plugin_property: bool = False, + is_plugin_property_editable: bool = False, + ): + self.key = key + self.value_type = value_type + self.value = value + self.description = description + self.is_different_across_tenants = is_different_across_tenants + self.possible_values = possible_values + self.is_nullable = is_nullable + self.default_value = default_value + self.is_plugin_property = is_plugin_property + self.is_plugin_property_editable = is_plugin_property_editable + + def to_json(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "key": self.key, + "valueType": self.value_type, + "value": self.value, + "description": self.description, + "isDifferentAcrossTenants": self.is_different_across_tenants, + "isNullable": self.is_nullable, + "defaultValue": self.default_value, + "isPluginProperty": self.is_plugin_property, + "isPluginPropertyEditable": self.is_plugin_property_editable, + } + if self.possible_values is not None: + result["possibleValues"] = self.possible_values + return result + + @staticmethod + def from_json(json: Dict[str, Any]) -> CoreConfigFieldInfo: + return CoreConfigFieldInfo( + key=json["key"], + value_type=json["valueType"], + value=json["value"], + description=json["description"], + is_different_across_tenants=json["isDifferentAcrossTenants"], + possible_values=json["possibleValues"], + is_nullable=json["isNullable"], + default_value=json["defaultValue"], + is_plugin_property=json["isPluginProperty"], + is_plugin_property_editable=json["isPluginPropertyEditable"], + ) diff --git a/supertokens_python/recipe/dashboard/recipe.py b/supertokens_python/recipe/dashboard/recipe.py index 5a5bd7b8f..93912f5ee 100644 --- a/supertokens_python/recipe/dashboard/recipe.py +++ b/supertokens_python/recipe/dashboard/recipe.py @@ -17,6 +17,69 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Dict, Any from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.dashboard.api.multitenancy.create_or_update_third_party_config import ( + handle_create_or_update_third_party_config, +) +from supertokens_python.recipe.dashboard.api.multitenancy.create_tenant import ( + create_tenant, +) +from supertokens_python.recipe.dashboard.api.multitenancy.delete_tenant import ( + delete_tenant_api, +) +from supertokens_python.recipe.dashboard.api.multitenancy.delete_third_party_config import ( + delete_third_party_config_api, +) +from supertokens_python.recipe.dashboard.api.multitenancy.get_tenant_info import ( + get_tenant_info, +) +from supertokens_python.recipe.dashboard.api.multitenancy.get_third_party_config import ( + get_third_party_config, +) +from supertokens_python.recipe.dashboard.api.multitenancy.list_all_tenants_with_login_methods import ( + list_all_tenants_with_login_methods, +) +from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_core_config import ( + update_tenant_core_config, +) +from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_first_factor import ( + update_tenant_first_factor, +) +from supertokens_python.recipe.dashboard.api.multitenancy.update_tenant_secondary_factor import ( + update_tenant_secondary_factor, +) +from supertokens_python.recipe.dashboard.api.user.create.emailpassword_user import ( + create_email_password_user, +) +from supertokens_python.recipe.dashboard.api.user.create.passwordless_user import ( + create_passwordless_user, +) +from supertokens_python.recipe.dashboard.api.userdetails.user_unlink_get import ( + handle_user_unlink_get, +) +from supertokens_python.recipe.dashboard.api.userroles.add_role_to_user import ( + add_role_to_user_api, +) +from supertokens_python.recipe.dashboard.api.userroles.get_role_to_user import ( + get_roles_for_user_api, +) +from supertokens_python.recipe.dashboard.api.userroles.permissions.get_permissions_for_role import ( + get_permissions_for_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.permissions.remove_permissions_from_role import ( + remove_permissions_from_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.remove_user_role import ( + remove_user_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.roles.create_role_or_add_permissions import ( + create_role_or_add_permissions_api, +) +from supertokens_python.recipe.dashboard.api.userroles.roles.delete_role import ( + delete_role_api, +) +from supertokens_python.recipe.dashboard.api.userroles.roles.get_all_roles import ( + get_all_roles_api, +) from supertokens_python.recipe_module import APIHandled, RecipeModule from .api import ( @@ -40,7 +103,6 @@ handle_users_count_get_api, handle_users_get_api, handle_validate_key_api, - handle_list_tenants_api, ) from .api.implementation import APIImplementation from .exceptions import SuperTokensDashboardError @@ -71,7 +133,19 @@ USERS_COUNT_API, USERS_LIST_GET_API, VALIDATE_KEY_API, - TENANTS_LIST_API, + TENANT_THIRD_PARTY_CONFIG_API, + TENANT_API, + LIST_TENANTS_WITH_LOGIN_METHODS, + UPDATE_TENANT_CORE_CONFIG_API, + UPDATE_TENANT_FIRST_FACTOR_API, + UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, + CREATE_EMAIL_PASSWORD_USER, + CREATE_PASSWORDLESS_USER, + UNLINK_USER, + USERROLES_PERMISSIONS_API, + USERROLES_REMOVE_PERMISSIONS_API, + USERROLES_ROLE_API, + USERROLES_USER_API, ) from .utils import ( InputOverrideConfig, @@ -247,9 +321,161 @@ def get_apis_handled(self) -> List[APIHandled]: False, ), APIHandled( - NormalisedURLPath(get_api_path_with_dashboard_base(TENANTS_LIST_API)), + NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), + "post", + TENANT_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), + "delete", + TENANT_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(TENANT_API)), "get", - TENANTS_LIST_API, + TENANT_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) + ), + "put", + TENANT_THIRD_PARTY_CONFIG_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) + ), + "delete", + TENANT_THIRD_PARTY_CONFIG_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(TENANT_THIRD_PARTY_CONFIG_API) + ), + "get", + TENANT_THIRD_PARTY_CONFIG_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(LIST_TENANTS_WITH_LOGIN_METHODS) + ), + "get", + LIST_TENANTS_WITH_LOGIN_METHODS, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(UPDATE_TENANT_CORE_CONFIG_API) + ), + "put", + UPDATE_TENANT_CORE_CONFIG_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(UPDATE_TENANT_FIRST_FACTOR_API) + ), + "put", + UPDATE_TENANT_FIRST_FACTOR_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base( + UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API + ) + ), + "put", + UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(CREATE_EMAIL_PASSWORD_USER) + ), + "post", + CREATE_EMAIL_PASSWORD_USER, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(CREATE_PASSWORDLESS_USER) + ), + "post", + CREATE_PASSWORDLESS_USER, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(UNLINK_USER)), + "get", + UNLINK_USER, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(USERROLES_PERMISSIONS_API) + ), + "get", + USERROLES_PERMISSIONS_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(USERROLES_PERMISSIONS_API) + ), + "put", + USERROLES_PERMISSIONS_API, + False, + ), + APIHandled( + NormalisedURLPath( + get_api_path_with_dashboard_base(USERROLES_REMOVE_PERMISSIONS_API) + ), + "put", + USERROLES_REMOVE_PERMISSIONS_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_ROLE_API)), + "put", + USERROLES_ROLE_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_ROLE_API)), + "delete", + USERROLES_ROLE_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_ROLE_API)), + "get", + USERROLES_ROLE_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_USER_API)), + "get", + USERROLES_USER_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_USER_API)), + "put", + USERROLES_USER_API, + False, + ), + APIHandled( + NormalisedURLPath(get_api_path_with_dashboard_base(USERROLES_USER_API)), + "delete", + USERROLES_USER_API, False, ), ] @@ -329,8 +555,52 @@ async def handle_api_request( elif request_id == DASHBOARD_ANALYTICS_API: if method == "post": api_function = handle_analytics_post - elif request_id == TENANTS_LIST_API: - api_function = handle_list_tenants_api + elif request_id == TENANT_API: + if method == "post": + api_function = create_tenant + if method == "delete": + api_function = delete_tenant_api + if method == "get": + api_function = get_tenant_info + elif request_id == TENANT_THIRD_PARTY_CONFIG_API: + if method == "put": + api_function = handle_create_or_update_third_party_config + if method == "delete": + api_function = delete_third_party_config_api + if method == "get": + api_function = get_third_party_config + elif request_id == LIST_TENANTS_WITH_LOGIN_METHODS: + api_function = list_all_tenants_with_login_methods + elif request_id == UPDATE_TENANT_CORE_CONFIG_API: + api_function = update_tenant_core_config + elif request_id == UPDATE_TENANT_FIRST_FACTOR_API: + api_function = update_tenant_first_factor + elif request_id == UPDATE_TENANT_REQUIRED_SECONDARY_FACTOR_API: + api_function = update_tenant_secondary_factor + elif request_id == CREATE_EMAIL_PASSWORD_USER: + api_function = create_email_password_user + elif request_id == CREATE_PASSWORDLESS_USER: + api_function = create_passwordless_user + elif request_id == UNLINK_USER: + api_function = handle_user_unlink_get + elif request_id == USERROLES_PERMISSIONS_API: + api_function = get_permissions_for_role_api + elif request_id == USERROLES_REMOVE_PERMISSIONS_API: + api_function = remove_permissions_from_role_api + elif request_id == USERROLES_ROLE_API: + if method == "put": + api_function = create_role_or_add_permissions_api + if method == "delete": + api_function = delete_role_api + if method == "get": + api_function = get_all_roles_api + elif request_id == USERROLES_USER_API: + if method == "get": + api_function = get_roles_for_user_api + if method == "put": + api_function = add_role_to_user_api + if method == "delete": + api_function = remove_user_role_api if api_function is not None: return await api_key_protector( diff --git a/supertokens_python/recipe/dashboard/utils.py b/supertokens_python/recipe/dashboard/utils.py index 209d93eb4..e05e53848 100644 --- a/supertokens_python/recipe/dashboard/utils.py +++ b/supertokens_python/recipe/dashboard/utils.py @@ -14,24 +14,17 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, List +from typing_extensions import Literal +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest from supertokens_python.recipe.emailpassword import EmailPasswordRecipe -from supertokens_python.recipe.emailpassword.asyncio import ( - get_user_by_id as ep_get_user_by_id, -) from supertokens_python.recipe.passwordless import PasswordlessRecipe -from supertokens_python.recipe.passwordless.asyncio import ( - get_user_by_id as pless_get_user_by_id, -) from supertokens_python.recipe.thirdparty import ThirdPartyRecipe -from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_id as tp_get_user_by_idx, -) -from supertokens_python.types import User -from supertokens_python.utils import Awaitable, log_debug_message, normalise_email +from supertokens_python.types import User, RecipeUserId +from supertokens_python.utils import log_debug_message, normalise_email from ...normalised_url_path import NormalisedURLPath from .constants import ( @@ -56,15 +49,9 @@ class UserWithMetadata: - user_id: str - time_joined: int - recipe_id: Optional[str] = None - email: Optional[str] = None - phone_number: Optional[str] = None - tp_info: Optional[Dict[str, Any]] = None + user: User first_name: Optional[str] = None last_name: Optional[str] = None - tenant_ids: List[str] def from_user( self, @@ -72,78 +59,15 @@ def from_user( first_name: Optional[str] = None, last_name: Optional[str] = None, ): - self.first_name = first_name - self.last_name = last_name - - self.user_id = user.user_id - # from_user() is called in /api/users (note extra s) - # here we DashboardUsersGetResponse() doesn't maintain - # recipe id for each user on its own. That's why we need - # to set self.recipe_id here. - self.recipe_id = user.recipe_id - self.time_joined = user.time_joined - self.email = user.email - self.phone_number = user.phone_number - self.tp_info = ( - None if user.third_party_info is None else user.third_party_info.__dict__ - ) - self.tenant_ids = user.tenant_ids - - return self - - def from_dict( - self, - user_obj_dict: Dict[str, Any], - first_name: Optional[str] = None, - last_name: Optional[str] = None, - ): - self.first_name = first_name - self.last_name = last_name - - self.user_id = user_obj_dict["user_id"] - # from_dict() is used in `/api/user` where - # recipe_id is already passed seperately to - # GetUserForRecipeIdResult object - # So we set recipe_id to None here - self.recipe_id = None - self.time_joined = user_obj_dict["time_joined"] - self.tenant_ids = user_obj_dict.get("tenant_ids", []) - - self.email = user_obj_dict.get("email") - self.phone_number = user_obj_dict.get("phone_number") - self.tp_info = ( - None - if user_obj_dict.get("third_party_info") is None - else user_obj_dict["third_party_info"].__dict__ - ) - + self.first_name = first_name or "" + self.last_name = last_name or "" + self.user = user return self def to_json(self) -> Dict[str, Any]: - user_json: Dict[str, Any] = { - "id": self.user_id, - "timeJoined": self.time_joined, - "tenantIds": self.tenant_ids, - } - if self.tp_info is not None: - user_json["thirdParty"] = { - "id": self.tp_info["id"], - "userId": self.tp_info["user_id"], - } - if self.phone_number is not None: - user_json["phoneNumber"] = self.phone_number - if self.email is not None: - user_json["email"] = self.email - if self.first_name is not None: - user_json["firstName"] = self.first_name - if self.last_name is not None: - user_json["lastName"] = self.last_name - - if self.recipe_id is not None: - return { - "recipeId": self.recipe_id, - "user": user_json, - } + user_json = self.user.to_json() + user_json["firstName"] = self.first_name + user_json["lastName"] = self.last_name return user_json @@ -259,103 +183,81 @@ def get_api_if_matched(path: NormalisedURLPath, method: str) -> Optional[str]: return None -def is_valid_recipe_id(recipe_id: str) -> bool: - return recipe_id in ("emailpassword", "thirdparty", "passwordless") +class GetUserForRecipeIdHelperResult: + def __init__(self, user: Optional[User] = None, recipe: Optional[str] = None): + self.user = user + self.recipe = recipe class GetUserForRecipeIdResult: - def __init__(self, user: UserWithMetadata, recipe: str): + def __init__( + self, user: Optional[UserWithMetadata] = None, recipe: Optional[str] = None + ): self.user = user self.recipe = recipe -if TYPE_CHECKING: - from supertokens_python.recipe.emailpassword.types import User as EmailPasswordUser - from supertokens_python.recipe.passwordless.types import User as PasswordlessUser - from supertokens_python.recipe.thirdparty.types import User as ThirdPartyUser - - GetUserResult = Union[ - EmailPasswordUser, - ThirdPartyUser, - PasswordlessUser, - None, - ] - - async def get_user_for_recipe_id( - user_id: str, recipe_id: str -) -> Optional[GetUserForRecipeIdResult]: - user: Optional[UserWithMetadata] = None - recipe: Optional[str] = None - - async def update_user_dict( - get_user_funcs: List[Callable[[str], Awaitable[GetUserResult]]], - recipes: List[str], - ): - nonlocal user, user_id, recipe - - for get_user_func, recipe_id in zip(get_user_funcs, recipes): - try: - recipe_user = await get_user_func(user_id) # type: ignore - - if recipe_user is not None: - user = UserWithMetadata().from_dict( - recipe_user.__dict__, first_name="", last_name="" - ) - recipe = recipe_id - break - except Exception: - pass + recipe_user_id: RecipeUserId, recipe_id: str, user_context: Dict[str, Any] +) -> GetUserForRecipeIdResult: + user_response = await _get_user_for_recipe_id( + recipe_user_id, recipe_id, user_context + ) - if recipe_id == EmailPasswordRecipe.recipe_id: - await update_user_dict( - [ep_get_user_by_id], - ["emailpassword"], + user = None + if user_response.user is not None: + user = UserWithMetadata().from_user( + user_response.user, first_name="", last_name="" ) - elif recipe_id == ThirdPartyRecipe.recipe_id: - await update_user_dict( - [tp_get_user_by_idx], - ["thirdparty"], - ) + return GetUserForRecipeIdResult(user=user, recipe=user_response.recipe) - elif recipe_id == PasswordlessRecipe.recipe_id: - await update_user_dict( - [pless_get_user_by_id], - ["passwordless"], - ) - if user is not None and recipe is not None: - return GetUserForRecipeIdResult(user, recipe) +async def _get_user_for_recipe_id( + recipe_user_id: RecipeUserId, recipe_id: str, user_context: Dict[str, Any] +) -> GetUserForRecipeIdHelperResult: + recipe: Optional[Literal["emailpassword", "thirdparty", "passwordless"]] = None - return None + user = await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + recipe_user_id.get_as_string(), user_context + ) + if user is None: + return GetUserForRecipeIdHelperResult(user=None, recipe=None) -def is_recipe_initialised(recipeId: str) -> bool: - isRecipeInitialised: bool = False + login_method = next( + ( + m + for m in user.login_methods + if m.recipe_id == recipe_id + and m.recipe_user_id.get_as_string() == recipe_user_id.get_as_string() + ), + None, + ) + + if login_method is None: + return GetUserForRecipeIdHelperResult(user=None, recipe=None) - if recipeId == EmailPasswordRecipe.recipe_id: + if recipe_id == EmailPasswordRecipe.recipe_id: try: EmailPasswordRecipe.get_instance() - isRecipeInitialised = True + recipe = "emailpassword" except Exception: pass - - elif recipeId == PasswordlessRecipe.recipe_id: + elif recipe_id == ThirdPartyRecipe.recipe_id: try: - PasswordlessRecipe.get_instance() - isRecipeInitialised = True + ThirdPartyRecipe.get_instance() + recipe = "thirdparty" except Exception: pass - - elif recipeId == ThirdPartyRecipe.recipe_id: + elif recipe_id == PasswordlessRecipe.recipe_id: try: - ThirdPartyRecipe.get_instance() - isRecipeInitialised = True + PasswordlessRecipe.get_instance() + recipe = "passwordless" except Exception: pass - return isRecipeInitialised + return GetUserForRecipeIdHelperResult(user=user, recipe=recipe) async def validate_api_key( diff --git a/supertokens_python/recipe/emailpassword/api/implementation.py b/supertokens_python/recipe/emailpassword/api/implementation.py index 54f90fd04..39490b324 100644 --- a/supertokens_python/recipe/emailpassword/api/implementation.py +++ b/supertokens_python/recipe/emailpassword/api/implementation.py @@ -13,41 +13,62 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import ( + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + get_authenticating_user_and_add_to_current_tenant_if_required, + is_fake_email, + post_auth_checks, + pre_auth_checks, +) from supertokens_python.logger import log_debug_message +from supertokens_python.recipe.accountlinking import ( + AccountInfoWithRecipeIdAndUserId, + ShouldNotAutomaticallyLink, +) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId from supertokens_python.recipe.emailpassword.constants import ( FORM_FIELD_EMAIL_ID, FORM_FIELD_PASSWORD_ID, ) from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface, - CreateResetPasswordWrongUserIdError, + CreateResetPasswordOkResult, + EmailAlreadyExistsError, EmailExistsGetOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, GeneratePasswordResetTokenPostOkResult, - PasswordResetPostInvalidTokenResponse, + LinkingToSessionUserFailedError, + PasswordPolicyViolationError, PasswordResetPostOkResult, - ResetPasswordUsingTokenInvalidTokenError, + PasswordResetTokenInvalidError, + SignInOkResult, + SignInPostNotAllowedResponse, SignInPostOkResult, - SignInPostWrongCredentialsError, - SignInWrongCredentialsError, - SignUpEmailAlreadyExistsError, - SignUpPostEmailAlreadyExistsError, + SignUpOkResult, + SignUpPostNotAllowedResponse, SignUpPostOkResult, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + WrongCredentialsError, ) from supertokens_python.recipe.emailpassword.types import ( + EmailTemplateVars, FormField, - PasswordResetEmailTemplateVars, PasswordResetEmailTemplateVarsUser, ) +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.totp.types import UnknownUserIdError from ..utils import get_password_reset_link -from supertokens_python.recipe.session.asyncio import create_new_session -from supertokens_python.utils import find_first_occurrence_in_list if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import APIOptions -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import AccountInfo, GeneralErrorResponse, RecipeUserId class APIImplementation(APIInterface): @@ -58,10 +79,23 @@ async def email_exists_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[EmailExistsGetOkResult, GeneralErrorResponse]: - user = await api_options.recipe_implementation.get_user_by_email( - email, tenant_id, user_context + # Check if there exists an email password user with the same email + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=email), + do_union_of_account_info=False, + user_context=user_context, ) - return EmailExistsGetOkResult(user is not None) + + email_password_user_exists = any( + any( + lm.recipe_id == "emailpassword" and lm.has_same_email_as(email) + for lm in user.login_methods + ) + for user in users + ) + + return EmailExistsGetOkResult(exists=email_password_user_exists) async def generate_password_reset_token_post( self, @@ -69,52 +103,214 @@ async def generate_password_reset_token_post( tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[GeneratePasswordResetTokenPostOkResult, GeneralErrorResponse]: - emailFormField = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_EMAIL_ID, form_fields + ) -> Union[ + GeneratePasswordResetTokenPostOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, + GeneralErrorResponse, + ]: + email = next(f.value for f in form_fields if f.id == "email") + + async def generate_and_send_password_reset_token( + primary_user_id: str, recipe_user_id: Optional[RecipeUserId] + ) -> Union[ + GeneratePasswordResetTokenPostOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, + GeneralErrorResponse, + ]: + user_id = ( + recipe_user_id.get_as_string() if recipe_user_id else primary_user_id + ) + response = ( + await api_options.recipe_implementation.create_reset_password_token( + tenant_id=tenant_id, + user_id=user_id, + email=email, + user_context=user_context, + ) + ) + if isinstance(response, UnknownUserIdError): + log_debug_message( + f"Password reset email not sent, unknown user id: {user_id}" + ) + return GeneratePasswordResetTokenPostOkResult() + + assert isinstance(response, CreateResetPasswordOkResult) + password_reset_link = get_password_reset_link( + app_info=api_options.app_info, + token=response.token, + tenant_id=tenant_id, + request=api_options.request, + user_context=user_context, + ) + + log_debug_message(f"Sending password reset email to {email}") + await api_options.email_delivery.ingredient_interface_impl.send_email( + EmailTemplateVars( + user=PasswordResetEmailTemplateVarsUser( + user_id=primary_user_id, + recipe_user_id=recipe_user_id, + email=email, + ), + password_reset_link=password_reset_link, + tenant_id=tenant_id, + ), + user_context=user_context, + ) + + return GeneratePasswordResetTokenPostOkResult() + + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=email), + do_union_of_account_info=False, + user_context=user_context, ) - if emailFormField is None: - raise Exception("Should never come here") - email = emailFormField.value - user = await api_options.recipe_implementation.get_user_by_email( - email, tenant_id, user_context + email_password_account = next( + ( + lm + for user in users + for lm in user.login_methods + if lm.recipe_id == "emailpassword" and lm.has_same_email_as(email) + ), + None, ) - if user is None: - return GeneratePasswordResetTokenPostOkResult() + linking_candidate = next((u for u in users if u.is_primary_user), None) - token_result = ( - await api_options.recipe_implementation.create_reset_password_token( - user.user_id, tenant_id, user_context - ) + # first we check if there even exists a primary user that has the input email + log_debug_message( + f"generatePasswordResetTokenPOST: primary linking candidate: {linking_candidate.id if linking_candidate else None}" + ) + log_debug_message( + f"generatePasswordResetTokenPOST: linking candidate count {len(users)}" ) - if isinstance(token_result, CreateResetPasswordWrongUserIdError): + # If there is no existing primary user and there is a single option to link + # we see if that user can become primary (and a candidate for linking) + if linking_candidate is None and len(users) > 0: + # If the only user that exists with this email is a non-primary emailpassword user, then we can just let them reset their password, because: + # we are not going to link anything and there is no risk of account takeover. + if ( + email_password_account is not None + and len(users) == 1 + and users[0].login_methods[0].recipe_user_id.get_as_string() + == email_password_account.recipe_user_id.get_as_string() + ): + return await generate_and_send_password_reset_token( + email_password_account.recipe_user_id.get_as_string(), + email_password_account.recipe_user_id, + ) + + oldest_user = min(users, key=lambda u: u.time_joined) log_debug_message( - "Password reset email not sent, unknown user id: %s", user.user_id + f"generatePasswordResetTokenPOST: oldest recipe level-linking candidate: {oldest_user.id} (w/ {'verified' if oldest_user.login_methods[0].verified else 'unverified'} email)" + ) + # Otherwise, we check if the user can become primary. + should_become_primary_user = ( + await AccountLinkingRecipe.get_instance().should_become_primary_user( + oldest_user, tenant_id, None, user_context + ) ) - return GeneratePasswordResetTokenPostOkResult() - password_reset_link = get_password_reset_link( - api_options.app_info, - token_result.token, - tenant_id, - api_options.request, - user_context, + log_debug_message( + f"generatePasswordResetTokenPOST: recipe level-linking candidate {'can' if should_become_primary_user else 'can not'} become primary" + ) + if should_become_primary_user: + linking_candidate = oldest_user + + if linking_candidate is None: + if email_password_account is None: + log_debug_message( + f"Password reset email not sent, unknown user email: {email}" + ) + return GeneratePasswordResetTokenPostOkResult() + return await generate_and_send_password_reset_token( + email_password_account.recipe_user_id.get_as_string(), + email_password_account.recipe_user_id, + ) + + email_verified = any( + lm.has_same_email_as(email) and lm.verified + for lm in linking_candidate.login_methods ) - log_debug_message("Sending password reset email to %s", email) - send_email_input = PasswordResetEmailTemplateVars( - user=PasswordResetEmailTemplateVarsUser(user.user_id, user.email), - password_reset_link=password_reset_link, - tenant_id=tenant_id, + has_other_email_or_phone = any( + (lm.email is not None and not lm.has_same_email_as(email)) + or lm.phone_number is not None + for lm in linking_candidate.login_methods ) - await api_options.email_delivery.ingredient_interface_impl.send_email( - send_email_input, user_context + + if not email_verified and has_other_email_or_phone: + return GeneratePasswordResetTokenPostNotAllowedResponse( + "Reset password link was not created because of account take over risk. Please contact support. (ERR_CODE_001)" + ) + + if linking_candidate.is_primary_user and email_password_account is not None: + # If a primary user has the input email as verified or has no other emails then it is always allowed to reset their own password: + # - there is no risk of account takeover, because they have verified this email or haven't linked it to anything else (checked above this block) + # - there will be no linking as a result of this action, so we do not need to check for linking (checked here by seeing that the two accounts are already linked) + are_the_two_accounts_linked = any( + lm.recipe_user_id.get_as_string() + == email_password_account.recipe_user_id.get_as_string() + for lm in linking_candidate.login_methods + ) + + if are_the_two_accounts_linked: + return await generate_and_send_password_reset_token( + linking_candidate.id, email_password_account.recipe_user_id + ) + + should_do_account_linking_response = await AccountLinkingRecipe.get_instance().config.should_do_automatic_account_linking( + AccountInfoWithRecipeIdAndUserId.from_account_info_or_login_method( + email_password_account + or AccountInfoWithRecipeId(email=email, recipe_id="emailpassword") + ), + linking_candidate, + None, + tenant_id, + user_context, ) - return GeneratePasswordResetTokenPostOkResult() + if email_password_account is None: + if isinstance( + should_do_account_linking_response, ShouldNotAutomaticallyLink + ): + log_debug_message( + "Password reset email not sent, since email password user didn't exist, and account linking not enabled" + ) + return GeneratePasswordResetTokenPostOkResult() + + is_sign_up_allowed = ( + await AccountLinkingRecipe.get_instance().is_sign_up_allowed( + new_user=AccountInfoWithRecipeId( + email=email, recipe_id="emailpassword" + ), + is_verified=True, + session=None, + tenant_id=tenant_id, + user_context=user_context, + ) + ) + if is_sign_up_allowed: + return await generate_and_send_password_reset_token( + linking_candidate.id, None + ) + else: + log_debug_message( + f"Password reset email not sent, is_sign_up_allowed returned false for email: {email}" + ) + return GeneratePasswordResetTokenPostOkResult() + + if isinstance(should_do_account_linking_response, ShouldNotAutomaticallyLink): + return await generate_and_send_password_reset_token( + email_password_account.recipe_user_id.get_as_string(), + email_password_account.recipe_user_id, + ) + + return await generate_and_send_password_reset_token( + linking_candidate.id, email_password_account.recipe_user_id + ) async def password_reset_post( self, @@ -125,103 +321,452 @@ async def password_reset_post( user_context: Dict[str, Any], ) -> Union[ PasswordResetPostOkResult, - PasswordResetPostInvalidTokenResponse, + PasswordResetTokenInvalidError, + PasswordPolicyViolationError, GeneralErrorResponse, ]: - new_password_for_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields + async def mark_email_as_verified(recipe_user_id: RecipeUserId, email: str): + email_verification_instance = ( + EmailVerificationRecipe.get_instance_optional() + ) + if email_verification_instance: + token_response = await email_verification_instance.recipe_implementation.create_email_verification_token( + tenant_id=tenant_id, + recipe_user_id=recipe_user_id, + email=email, + user_context=user_context, + ) + + if token_response.status == "OK": + await email_verification_instance.recipe_implementation.verify_email_using_token( + tenant_id=tenant_id, + token=token_response.token, + attempt_account_linking=False, + user_context=user_context, + ) + + async def do_update_password_and_verify_email_and_try_link_if_not_primary( + recipe_user_id: RecipeUserId, + ): + update_response = ( + await api_options.recipe_implementation.update_email_or_password( + tenant_id_for_password_policy=tenant_id, + email=None, + recipe_user_id=recipe_user_id, + password=new_password, + apply_password_policy=None, + user_context=user_context, + ) + ) + + if isinstance( + update_response, + ( + EmailAlreadyExistsError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + ), + ): + raise Exception("Should never happen") + if isinstance(update_response, UnknownUserIdError): + return PasswordResetTokenInvalidError() + elif isinstance(update_response, PasswordPolicyViolationError): + return update_response + else: + await mark_email_as_verified( + recipe_user_id, email_for_whom_token_was_generated + ) + updated_user_after_email_verification = await get_user( + recipe_user_id.get_as_string(), user_context + ) + if updated_user_after_email_verification is None: + raise Exception( + "Should never happen - user deleted after during password reset" + ) + + if updated_user_after_email_verification.is_primary_user: + return PasswordResetPostOkResult( + user=updated_user_after_email_verification, + email=email_for_whom_token_was_generated, + ) + + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=updated_user_after_email_verification, + session=None, + user_context=user_context, + ) + user_after_we_tried_linking = ( + link_res.user + if link_res.status == "OK" + else updated_user_after_email_verification + ) + + assert user_after_we_tried_linking is not None + + return PasswordResetPostOkResult( + user=user_after_we_tried_linking, + email=email_for_whom_token_was_generated, + ) + + new_password = next(f.value for f in form_fields if f.id == "password") + + token_consumption_response = ( + await api_options.recipe_implementation.consume_password_reset_token( + token=token, + tenant_id=tenant_id, + user_context=user_context, + ) + ) + + if isinstance(token_consumption_response, PasswordResetTokenInvalidError): + return PasswordResetTokenInvalidError() + + user_id_for_whom_token_was_generated = token_consumption_response.user_id + email_for_whom_token_was_generated = token_consumption_response.email + + existing_user = await get_user( + user_id_for_whom_token_was_generated, user_context ) - if new_password_for_field is None: - raise Exception("Should never come here") - new_password = new_password_for_field.value - result = await api_options.recipe_implementation.reset_password_using_token( - token, new_password, tenant_id, user_context + if existing_user is None: + return PasswordResetTokenInvalidError() + + token_generated_for_email_password_user = any( + lm.recipe_user_id.get_as_string() == user_id_for_whom_token_was_generated + and lm.recipe_id == "emailpassword" + for lm in existing_user.login_methods ) - if isinstance(result, ResetPasswordUsingTokenInvalidTokenError): - return PasswordResetPostInvalidTokenResponse() + if token_generated_for_email_password_user: + if not existing_user.is_primary_user: + # If this is a recipe level emailpassword user, we can always allow them to reset their password. + return await do_update_password_and_verify_email_and_try_link_if_not_primary( + RecipeUserId(user_id_for_whom_token_was_generated) + ) + + # If the user is a primary user resetting the password of an emailpassword user linked to it + # we need to check for account takeover risk (similar to what we do when generating the token) + + # We check if there is any login method in which the input email is verified. + # If that is the case, then it's proven that the user owns the email and we can + # trust linking of the email password account. + email_verified = any( + lm.has_same_email_as(email_for_whom_token_was_generated) and lm.verified + for lm in existing_user.login_methods + ) + + # finally, we check if the primary user has any other email / phone number + # associated with this account - and if it does, then it means that + # there is a risk of account takeover, so we do not allow the token to be generated + has_other_email_or_phone = any( + ( + lm.email is not None + and not lm.has_same_email_as(email_for_whom_token_was_generated) + ) + or lm.phone_number is not None + for lm in existing_user.login_methods + ) + + if not email_verified and has_other_email_or_phone: + # We can return an invalid token error, because in this case the token should not have been created + # whenever they try to re-create it they'll see the appropriate error message + return PasswordResetTokenInvalidError() - return PasswordResetPostOkResult(result.user_id) + # since this doesn't result in linking and there is no risk of account takeover, we can allow the password reset to proceed + return ( + await do_update_password_and_verify_email_and_try_link_if_not_primary( + RecipeUserId(user_id_for_whom_token_was_generated) + ) + ) + + create_user_response = ( + await api_options.recipe_implementation.create_new_recipe_user( + tenant_id=tenant_id, + email=token_consumption_response.email, + password=new_password, + user_context=user_context, + ) + ) + if isinstance(create_user_response, EmailAlreadyExistsError): + return PasswordResetTokenInvalidError() + else: + await mark_email_as_verified( + create_user_response.user.login_methods[0].recipe_user_id, + token_consumption_response.email, + ) + updated_user = await get_user( + create_user_response.user.id, + user_context, + ) + if updated_user is None: + raise Exception( + "Should never happen - user deleted after during password reset" + ) + create_user_response.user = updated_user + link_res = await AccountLinkingRecipe.get_instance().try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=create_user_response.user, + session=None, + user_context=user_context, + ) + user_after_linking = ( + link_res.user if link_res.status == "OK" else create_user_response.user + ) + assert user_after_linking is not None + return PasswordResetPostOkResult( + user=user_after_linking, + email=token_consumption_response.email, + ) async def sign_in_post( self, form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignInPostOkResult, SignInPostWrongCredentialsError, GeneralErrorResponse + SignInPostOkResult, + WrongCredentialsError, + SignInPostNotAllowedResponse, + GeneralErrorResponse, ]: - password_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields + error_code_map = { + "SIGN_IN_NOT_ALLOWED": "Cannot sign in due to security reasons. Please try resetting your password, use a different login method or contact support. (ERR_CODE_008)", + "LINKING_TO_SESSION_USER_FAILED": { + "EMAIL_VERIFICATION_REQUIRED": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_009)", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_010)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_011)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_012)", + }, + } + + email = next(f.value for f in form_fields if f.id == FORM_FIELD_EMAIL_ID) + password = next(f.value for f in form_fields if f.id == FORM_FIELD_PASSWORD_ID) + + recipe_id = "emailpassword" + + async def check_credentials_on_tenant(tenant_id: str) -> bool: + verify_result = await api_options.recipe_implementation.verify_credentials( + email=email, + password=password, + tenant_id=tenant_id, + user_context=user_context, + ) + return isinstance(verify_result, SignInOkResult) + + if is_fake_email(email) and session is None: + return WrongCredentialsError() + + authenticating_user = ( + await get_authenticating_user_and_add_to_current_tenant_if_required( + email=email, + phone_number=None, + third_party=None, + user_context=user_context, + recipe_id=recipe_id, + session=session, + tenant_id=tenant_id, + check_credentials_on_tenant=check_credentials_on_tenant, + ) ) - if password_form_field is None: - raise Exception("Should never come here") - password = password_form_field.value - email_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_EMAIL_ID, form_fields + is_verified = ( + authenticating_user is not None + and authenticating_user.login_method is not None + and authenticating_user.login_method.verified ) - if email_form_field is None: - raise Exception("Should never come here") - email = email_form_field.value - result = await api_options.recipe_implementation.sign_in( - email, password, tenant_id, user_context + if authenticating_user is None: + return WrongCredentialsError() + + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id=recipe_id, + email=email, + ), + factor_ids=["emailpassword"], + is_sign_up=False, + authenticating_user=authenticating_user.user, + is_verified=is_verified, + sign_in_verifies_login_method=False, + skip_session_user_update_in_core=False, + tenant_id=tenant_id, + should_try_linking_with_session_user=should_try_linking_with_session_user, + user_context=user_context, + session=session, ) - if isinstance(result, SignInWrongCredentialsError): - return SignInPostWrongCredentialsError() + if pre_auth_checks_result.status != "OK": + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + raise Exception("Should never happen") + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInPostNotAllowedResponse(reason=reason) - user = result.user - session = await create_new_session( + if is_fake_email(email) and pre_auth_checks_result.is_first_factor: + return WrongCredentialsError() + + sign_in_response = await api_options.recipe_implementation.sign_in( + email=email, + password=password, + session=session, tenant_id=tenant_id, - request=api_options.request, - user_id=user.user_id, - access_token_payload={}, - session_data_in_database={}, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if isinstance(sign_in_response, WrongCredentialsError): + return WrongCredentialsError() + if isinstance(sign_in_response, LinkingToSessionUserFailedError): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[sign_in_response.reason] + return SignInPostNotAllowedResponse(reason=reason) + + post_auth_checks_result = await post_auth_checks( + authenticated_user=sign_in_response.user, + recipe_user_id=sign_in_response.recipe_user_id, + is_sign_up=False, + factor_id="emailpassword", + session=session, + tenant_id=tenant_id, + user_context=user_context, + request=api_options.request, + ) + + if post_auth_checks_result.status != "OK": + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInPostNotAllowedResponse(reason) + + return SignInPostOkResult( + user=post_auth_checks_result.user, + session=post_auth_checks_result.session, ) - return SignInPostOkResult(user, session) async def sign_up_post( self, form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignUpPostOkResult, SignUpPostEmailAlreadyExistsError, GeneralErrorResponse + SignUpPostOkResult, + EmailAlreadyExistsError, + SignUpPostNotAllowedResponse, + GeneralErrorResponse, ]: - password_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields - ) - if password_form_field is None: - raise Exception("Should never come here") - password = password_form_field.value + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign up due to security reasons. Please try logging in, use a different login method or contact support. (ERR_CODE_007)", + "LINKING_TO_SESSION_USER_FAILED": { + "EMAIL_VERIFICATION_REQUIRED": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_013)", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_014)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_015)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_016)", + }, + } - email_form_field = find_first_occurrence_in_list( - lambda x: x.id == FORM_FIELD_EMAIL_ID, form_fields - ) - if email_form_field is None: - raise Exception("Should never come here") - email = email_form_field.value + email = next(f.value for f in form_fields if f.id == "email") + password = next(f.value for f in form_fields if f.id == "password") - result = await api_options.recipe_implementation.sign_up( - email, password, tenant_id, user_context + pre_auth_check_res = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id="emailpassword", + email=email, + ), + factor_ids=["emailpassword"], + is_sign_up=True, + is_verified=is_fake_email(email), + sign_in_verifies_login_method=False, + skip_session_user_update_in_core=False, + authenticating_user=None, # since this is a sign up, this is None + tenant_id=tenant_id, + user_context=user_context, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) - if isinstance(result, SignUpEmailAlreadyExistsError): - return SignUpPostEmailAlreadyExistsError() + if pre_auth_check_res.status == "SIGN_UP_NOT_ALLOWED": + conflicting_users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo( + email=email, + ), + do_union_of_account_info=False, + user_context=user_context, + ) + if any( + any( + lm.recipe_id == "emailpassword" and lm.has_same_email_as(email) + for lm in u.login_methods + ) + for u in conflicting_users + ): + return EmailAlreadyExistsError() + + if pre_auth_check_res.status != "OK": + if isinstance(pre_auth_check_res, SignInNotAllowedResponse): + raise Exception("Should never happen") + if isinstance(pre_auth_check_res, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignUpPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_check_res.reason] + return SignUpPostNotAllowedResponse(reason=reason) - user = result.user - session = await create_new_session( + if is_fake_email(email) and pre_auth_check_res.is_first_factor: + # Fake emails cannot be used as a first factor + return EmailAlreadyExistsError() + + sign_up_response = await api_options.recipe_implementation.sign_up( tenant_id=tenant_id, + email=email, + password=password, + session=session, + user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if isinstance(sign_up_response, EmailAlreadyExistsError): + return sign_up_response + if not isinstance(sign_up_response, SignUpOkResult): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[sign_up_response.reason] + return SignUpPostNotAllowedResponse(reason=reason) + + post_auth_checks_res = await post_auth_checks( + authenticated_user=sign_up_response.user, + recipe_user_id=sign_up_response.recipe_user_id, + is_sign_up=True, + factor_id="emailpassword", + session=session, request=api_options.request, - user_id=user.user_id, - access_token_payload={}, - session_data_in_database={}, + tenant_id=tenant_id, user_context=user_context, ) - return SignUpPostOkResult(user, session) + + if post_auth_checks_res.status != "OK": + # this will fail cause error_code_map doesn't have SIGN_IN_NOT_ALLOWED + # but that's ok, cause it should never come here for sign up anyway. + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignUpPostNotAllowedResponse(reason) + + return SignUpPostOkResult( + user=post_auth_checks_res.user, + session=post_auth_checks_res.session, + ) diff --git a/supertokens_python/recipe/emailpassword/api/password_reset.py b/supertokens_python/recipe/emailpassword/api/password_reset.py index 8ae7cbc7a..af72d7a4b 100644 --- a/supertokens_python/recipe/emailpassword/api/password_reset.py +++ b/supertokens_python/recipe/emailpassword/api/password_reset.py @@ -14,6 +14,14 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.recipe.emailpassword.exceptions import ( + raise_form_field_exception, +) + +from supertokens_python.recipe.emailpassword.interfaces import ( + PasswordPolicyViolationError, +) +from supertokens_python.recipe.emailpassword.types import ErrorFormField if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import ( @@ -55,4 +63,14 @@ async def handle_password_reset_api( response = await api_implementation.password_reset_post( form_fields, token, tenant_id, api_options, user_context ) + if isinstance(response, PasswordPolicyViolationError): + return raise_form_field_exception( + "Error in input formFields", + [ + ErrorFormField( + id="password", + error=response.failure_reason, + ) + ], + ) return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailpassword/api/signin.py b/supertokens_python/recipe/emailpassword/api/signin.py index 6dfd770f8..7b99d5f36 100644 --- a/supertokens_python/recipe/emailpassword/api/signin.py +++ b/supertokens_python/recipe/emailpassword/api/signin.py @@ -14,6 +14,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed +from supertokens_python.recipe.emailpassword.interfaces import SignInPostOkResult if TYPE_CHECKING: from supertokens_python.recipe.emailpassword.interfaces import ( @@ -22,7 +24,11 @@ ) from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) from .utils import validate_form_fields_or_throw_error @@ -43,8 +49,40 @@ async def handle_sign_in_api( api_options.config.sign_in_feature.form_fields, form_fields_raw, tenant_id ) + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context + ) + if session is not None: + tenant_id = session.get_tenant_id() + response = await api_implementation.sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) + if isinstance(response, SignInPostOkResult): + return send_200_response( + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=response.user, + session_container=response.session, + created_new_recipe_user=None, + user_context=user_context, + ), + }, + api_options.response, + ) + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailpassword/api/signup.py b/supertokens_python/recipe/emailpassword/api/signup.py index e8a1ce1f9..21b2ec9bc 100644 --- a/supertokens_python/recipe/emailpassword/api/signup.py +++ b/supertokens_python/recipe/emailpassword/api/signup.py @@ -14,8 +14,12 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed -from supertokens_python.recipe.emailpassword.interfaces import SignUpPostOkResult +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + SignUpPostOkResult, +) from supertokens_python.types import GeneralErrorResponse from ..exceptions import raise_form_field_exception @@ -28,7 +32,11 @@ ) from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) from .utils import validate_form_fields_or_throw_error @@ -49,21 +57,53 @@ async def handle_sign_up_api( api_options.config.sign_up_feature.form_fields, form_fields_raw, tenant_id ) + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context + ) + + if session is not None: + tenant_id = session.get_tenant_id() + response = await api_implementation.sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) if isinstance(response, SignUpPostOkResult): - return send_200_response(response.to_json(), api_options.response) + return send_200_response( + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=response.user, + session_container=response.session, + created_new_recipe_user=None, + user_context=user_context, + ), + }, + api_options.response, + ) if isinstance(response, GeneralErrorResponse): return send_200_response(response.to_json(), api_options.response) - return raise_form_field_exception( - "EMAIL_ALREADY_EXISTS_ERROR", - [ - ErrorFormField( - id="email", - error="This email already exists. Please sign in instead.", - ) - ], - ) + if isinstance(response, EmailAlreadyExistsError): + return raise_form_field_exception( + "EMAIL_ALREADY_EXISTS_ERROR", + [ + ErrorFormField( + id="email", + error="This email already exists. Please sign in instead.", + ) + ], + ) + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailpassword/asyncio/__init__.py b/supertokens_python/recipe/emailpassword/asyncio/__init__.py index a97f47367..5f4f62795 100644 --- a/supertokens_python/recipe/emailpassword/asyncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/asyncio/__init__.py @@ -12,19 +12,30 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict, Union, Optional +from typing_extensions import Literal from supertokens_python import get_request_from_user_context +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.emailpassword import EmailPasswordRecipe +from supertokens_python.recipe.session import SessionContainer -from ..types import EmailTemplateVars, User +from ..types import EmailTemplateVars from ...multitenancy.constants import DEFAULT_TENANT_ID +from ....types import RecipeUserId from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordWrongUserIdError, - CreateResetPasswordLinkUnknownUserIdError, - CreateResetPasswordLinkOkResult, - SendResetPasswordEmailOkResult, - SendResetPasswordEmailUnknownUserIdError, + CreateResetPasswordOkResult, + ConsumePasswordResetTokenOkResult, + PasswordResetTokenInvalidError, + UpdateEmailOrPasswordOkResult, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, + SignUpOkResult, + EmailAlreadyExistsError, + SignInOkResult, + WrongCredentialsError, ) from supertokens_python.recipe.emailpassword.utils import get_password_reset_link from supertokens_python.recipe.emailpassword.types import ( @@ -33,55 +44,67 @@ ) -async def update_email_or_password( - user_id: str, - email: Union[str, None] = None, - password: Union[str, None] = None, - apply_password_policy: Union[bool, None] = None, - tenant_id_for_password_policy: Optional[str] = None, - user_context: Union[None, Dict[str, Any]] = None, -): +async def sign_up( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.update_email_or_password( - user_id, - email, - password, - apply_password_policy, - tenant_id_for_password_policy or DEFAULT_TENANT_ID, - user_context, + return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_up( + email=email, + password=password, + tenant_id=tenant_id or DEFAULT_TENANT_ID, + session=session, + user_context=user_context, + should_try_linking_with_session_user=session is not None, ) -async def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[None, User]: +async def sign_in( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: if user_context is None: user_context = {} - return ( - await EmailPasswordRecipe.get_instance().recipe_implementation.get_user_by_id( - user_id, user_context - ) + return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_in( + email=email, + password=password, + tenant_id=tenant_id or DEFAULT_TENANT_ID, + session=session, + user_context=user_context, + should_try_linking_with_session_user=session is not None, ) -async def get_user_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: +async def verify_credentials( + tenant_id: str, + email: str, + password: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.get_user_by_email( - email, tenant_id, user_context + return await EmailPasswordRecipe.get_instance().recipe_implementation.verify_credentials( + email, password, tenant_id or DEFAULT_TENANT_ID, user_context ) async def create_reset_password_token( - tenant_id: str, user_id: str, user_context: Union[None, Dict[str, Any]] = None -): + tenant_id: str, + user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: if user_context is None: user_context = {} return await EmailPasswordRecipe.get_instance().recipe_implementation.create_reset_password_token( - user_id, tenant_id, user_context + user_id, email, tenant_id or DEFAULT_TENANT_ID, user_context ) @@ -89,91 +112,148 @@ async def reset_password_using_token( tenant_id: str, token: str, new_password: str, - user_context: Union[None, Dict[str, Any]] = None, -): + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + PasswordPolicyViolationError, + PasswordResetTokenInvalidError, + UnknownUserIdError, +]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.reset_password_using_token( - token, new_password, tenant_id, user_context + consume_resp = await consume_password_reset_token(tenant_id, token, user_context) + if not isinstance(consume_resp, ConsumePasswordResetTokenOkResult): + return consume_resp + + result = await update_email_or_password( + recipe_user_id=RecipeUserId(consume_resp.user_id), + email=consume_resp.email, + password=new_password, + tenant_id_for_password_policy=tenant_id, + user_context=user_context, ) + if isinstance( + result, + (EmailAlreadyExistsError, UpdateEmailOrPasswordEmailChangeNotAllowedError), + ): + raise Exception("Should never happen") -async def sign_in( - tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_in( - email, password, tenant_id, user_context - ) + return result -async def sign_up( +async def consume_password_reset_token( tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -): + token: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().recipe_implementation.sign_up( - email, password, tenant_id, user_context + return await EmailPasswordRecipe.get_instance().recipe_implementation.consume_password_reset_token( + token, tenant_id or DEFAULT_TENANT_ID, user_context ) -async def send_email( - input_: EmailTemplateVars, - user_context: Union[None, Dict[str, Any]] = None, -): +async def update_email_or_password( + recipe_user_id: RecipeUserId, + email: Optional[str] = None, + password: Optional[str] = None, + apply_password_policy: Optional[bool] = None, + tenant_id_for_password_policy: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, +]: if user_context is None: user_context = {} - return await EmailPasswordRecipe.get_instance().email_delivery.ingredient_interface_impl.send_email( - input_, user_context + return await EmailPasswordRecipe.get_instance().recipe_implementation.update_email_or_password( + recipe_user_id, + email, + password, + apply_password_policy, + tenant_id_for_password_policy or DEFAULT_TENANT_ID, + user_context, ) async def create_reset_password_link( - tenant_id: str, user_id: str, user_context: Optional[Dict[str, Any]] = None -): + tenant_id: str, + user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[str, UnknownUserIdError]: if user_context is None: user_context = {} - token = await create_reset_password_token(tenant_id, user_id, user_context) - if isinstance(token, CreateResetPasswordWrongUserIdError): - return CreateResetPasswordLinkUnknownUserIdError() + token = await create_reset_password_token(tenant_id, user_id, email, user_context) + if isinstance(token, UnknownUserIdError): + return token recipe_instance = EmailPasswordRecipe.get_instance() request = get_request_from_user_context(user_context) - return CreateResetPasswordLinkOkResult( - link=get_password_reset_link( - recipe_instance.get_app_info(), - token.token, - tenant_id, - request, - user_context, - ) + return get_password_reset_link( + recipe_instance.get_app_info(), + token.token, + tenant_id or DEFAULT_TENANT_ID, + request, + user_context, ) async def send_reset_password_email( - tenant_id: str, user_id: str, user_context: Optional[Dict[str, Any]] = None -): - link = await create_reset_password_link(tenant_id, user_id, user_context) - if isinstance(link, CreateResetPasswordLinkUnknownUserIdError): - return SendResetPasswordEmailUnknownUserIdError() + tenant_id: str, + user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[Literal["UNKNOWN_USER_ID_ERROR"], Literal["OK"]]: + if user_context is None: + user_context = {} + + user = await get_user(user_id, user_context) + if user is None: + return "UNKNOWN_USER_ID_ERROR" + + login_method = next( + ( + m + for m in user.login_methods + if m.recipe_id == "emailpassword" and m.has_same_email_as(email) + ), + None, + ) + if login_method is None: + return "UNKNOWN_USER_ID_ERROR" - user = await get_user_by_id(user_id, user_context) - assert user is not None + link = await create_reset_password_link(tenant_id, user_id, email, user_context) + if isinstance(link, UnknownUserIdError): + return "UNKNOWN_USER_ID_ERROR" + assert login_method.email is not None await send_email( PasswordResetEmailTemplateVars( - PasswordResetEmailTemplateVarsUser(user.user_id, user.email), - link.link, - tenant_id, + user=PasswordResetEmailTemplateVarsUser( + user_id=user.id, + email=login_method.email, + recipe_user_id=login_method.recipe_user_id, + ), + password_reset_link=link, + tenant_id=tenant_id or DEFAULT_TENANT_ID, ), user_context, ) - return SendResetPasswordEmailOkResult() + return "OK" + + +async def send_email( + input_: EmailTemplateVars, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + return await EmailPasswordRecipe.get_instance().email_delivery.ingredient_interface_impl.send_email( + input_, user_context + ) diff --git a/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py b/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py index ee295789e..ab256d780 100644 --- a/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py +++ b/supertokens_python/recipe/emailpassword/emaildelivery/services/backward_compatibility/__init__.py @@ -24,13 +24,17 @@ EmailTemplateVars, RecipeInterface, ) -from supertokens_python.recipe.emailpassword.types import User +from supertokens_python.recipe.emailpassword.types import ( + PasswordResetEmailTemplateVarsUser, +) from supertokens_python.supertokens import AppInfo from supertokens_python.utils import handle_httpx_client_exceptions async def create_and_send_email_using_supertokens_service( - app_info: AppInfo, user: User, password_reset_url_with_token: str + app_info: AppInfo, + user: PasswordResetEmailTemplateVarsUser, + password_reset_url_with_token: str, ) -> None: if ("SUPERTOKENS_ENV" in environ) and (environ["SUPERTOKENS_ENV"] == "testing"): return @@ -66,19 +70,12 @@ async def send_email( template_vars: EmailTemplateVars, user_context: Dict[str, Any], ) -> None: - user = await self.recipe_interface_impl.get_user_by_id( - user_id=template_vars.user.id, user_context=user_context - ) - if user is None: - raise Exception("Should never come here") - # we add this here cause the user may have overridden the sendEmail function # to change the input email and if we don't do this, the input email # will get reset by the getUserById call above. - user.email = template_vars.user.email try: await create_and_send_email_using_supertokens_service( - self.app_info, user, template_vars.password_reset_link + self.app_info, template_vars.user, template_vars.password_reset_link ) except Exception: pass diff --git a/supertokens_python/recipe/emailpassword/interfaces.py b/supertokens_python/recipe/emailpassword/interfaces.py index 3567754fc..60fad3644 100644 --- a/supertokens_python/recipe/emailpassword/interfaces.py +++ b/supertokens_python/recipe/emailpassword/interfaces.py @@ -15,37 +15,59 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Union +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.emailpassword.types import EmailTemplateVars from ...supertokens import AppInfo - -from ...types import APIResponse, GeneralErrorResponse +from ...types import ( + APIResponse, + GeneralErrorResponse, + RecipeUserId, +) if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer - from .types import FormField, User + from .types import FormField + from ...types import User from .utils import EmailPasswordConfig class SignUpOkResult: - def __init__(self, user: User): + status: str = "OK" + + def __init__(self, user: User, recipe_user_id: RecipeUserId): self.user = user + self.recipe_user_id = recipe_user_id + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "user": self.user.to_json(), + "recipeUserId": self.recipe_user_id.get_as_string(), + } -class SignUpEmailAlreadyExistsError: - pass + +class EmailAlreadyExistsError(APIResponse): + status: str = "EMAIL_ALREADY_EXISTS_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} class SignInOkResult: - def __init__(self, user: User): + def __init__(self, user: User, recipe_user_id: RecipeUserId): self.user = user + self.recipe_user_id = recipe_user_id -class SignInWrongCredentialsError: - pass +class WrongCredentialsError(APIResponse): + status: str = "WRONG_CREDENTIALS_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} class CreateResetPasswordOkResult: @@ -53,105 +75,124 @@ def __init__(self, token: str): self.token = token -class CreateResetPasswordWrongUserIdError: - pass - - -class CreateResetPasswordLinkOkResult: - def __init__(self, link: str): - self.link = link - - -class CreateResetPasswordLinkUnknownUserIdError: - pass - - -class SendResetPasswordEmailOkResult: - pass - - -class SendResetPasswordEmailUnknownUserIdError: - pass +class ConsumePasswordResetTokenOkResult: + def __init__(self, email: str, user_id: str): + self.email = email + self.user_id = user_id + def to_json(self) -> Dict[str, Any]: + return { + "email": self.email, + "userId": self.user_id, + } -class ResetPasswordUsingTokenOkResult: - def __init__(self, user_id: Union[str, None]): - self.user_id = user_id +class PasswordResetTokenInvalidError(APIResponse): + status: str = "RESET_PASSWORD_INVALID_TOKEN_ERROR" -class ResetPasswordUsingTokenInvalidTokenError: - pass + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} class UpdateEmailOrPasswordOkResult: pass -class UpdateEmailOrPasswordEmailAlreadyExistsError: +class UnknownUserIdError: pass -class UpdateEmailOrPasswordUnknownUserIdError: - pass - +class UpdateEmailOrPasswordEmailChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason -class UpdateEmailOrPasswordPasswordPolicyViolationError: - failure_reason: str +class PasswordPolicyViolationError(APIResponse): def __init__(self, failure_reason: str): self.failure_reason = failure_reason + def to_json(self) -> Dict[str, Any]: + return { + "status": "PASSWORD_POLICY_VIOLATED_ERROR", + "failureReason": self.failure_reason, + } + class RecipeInterface(ABC): def __init__(self): pass @abstractmethod - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: + async def sign_up( + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], + ) -> Union[ + SignUpOkResult, + EmailAlreadyExistsError, + LinkingToSessionUserFailedError, + ]: pass @abstractmethod - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: + async def create_new_recipe_user( + self, + email: str, + password: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[SignUpOkResult, EmailAlreadyExistsError]: pass @abstractmethod - async def create_reset_password_token( - self, user_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[CreateResetPasswordOkResult, CreateResetPasswordWrongUserIdError]: + async def sign_in( + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], + ) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError,]: pass @abstractmethod - async def reset_password_using_token( + async def verify_credentials( self, - token: str, - new_password: str, + email: str, + password: str, tenant_id: str, user_context: Dict[str, Any], - ) -> Union[ - ResetPasswordUsingTokenOkResult, ResetPasswordUsingTokenInvalidTokenError - ]: + ) -> Union[SignInOkResult, WrongCredentialsError]: pass @abstractmethod - async def sign_in( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignInOkResult, SignInWrongCredentialsError]: + async def create_reset_password_token( + self, + user_id: str, + email: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: pass @abstractmethod - async def sign_up( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignUpOkResult, SignUpEmailAlreadyExistsError]: + async def consume_password_reset_token( + self, + token: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: pass @abstractmethod async def update_email_or_password( self, - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None], password: Union[str, None], apply_password_policy: Union[bool, None], @@ -159,9 +200,10 @@ async def update_email_or_password( user_context: Dict[str, Any], ) -> Union[ UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, - UpdateEmailOrPasswordUnknownUserIdError, - UpdateEmailOrPasswordPasswordPolicyViolationError, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, ]: pass @@ -203,18 +245,22 @@ def to_json(self) -> Dict[str, Any]: return {"status": self.status} -class PasswordResetPostOkResult(APIResponse): - status: str = "OK" +class GeneratePasswordResetTokenPostNotAllowedResponse(APIResponse): + status: str = "PASSWORD_RESET_NOT_ALLOWED" - def __init__(self, user_id: Union[str, None]): - self.user_id = user_id + def __init__(self, reason: str): + self.reason = reason def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "reason": self.reason} -class PasswordResetPostInvalidTokenResponse(APIResponse): - status: str = "RESET_PASSWORD_INVALID_TOKEN_ERROR" +class PasswordResetPostOkResult(APIResponse): + status: str = "OK" + + def __init__(self, email: str, user: User): + self.email = email + self.user = user def to_json(self) -> Dict[str, Any]: return {"status": self.status} @@ -230,19 +276,18 @@ def __init__(self, user: User, session: SessionContainer): def to_json(self) -> Dict[str, Any]: return { "status": self.status, - "user": { - "id": self.user.user_id, - "email": self.user.email, - "timeJoined": self.user.time_joined, - }, + "user": self.user.to_json(), } -class SignInPostWrongCredentialsError(APIResponse): - status: str = "WRONG_CREDENTIALS_ERROR" +class SignInPostNotAllowedResponse(APIResponse): + status: str = "SIGN_IN_NOT_ALLOWED" + + def __init__(self, reason: str): + self.reason = reason def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "reason": self.reason} class SignUpPostOkResult(APIResponse): @@ -253,21 +298,17 @@ def __init__(self, user: User, session: SessionContainer): self.session = session def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "user": { - "id": self.user.user_id, - "email": self.user.email, - "timeJoined": self.user.time_joined, - }, - } + return {"status": self.status, "user": self.user.to_json()} -class SignUpPostEmailAlreadyExistsError(APIResponse): - status: str = "EMAIL_ALREADY_EXISTS_ERROR" +class SignUpPostNotAllowedResponse(APIResponse): + status: str = "SIGN_UP_NOT_ALLOWED" + + def __init__(self, reason: str): + self.reason = reason def to_json(self) -> Dict[str, Any]: - return {"status": self.status} + return {"status": self.status, "reason": self.reason} class APIInterface: @@ -295,7 +336,11 @@ async def generate_password_reset_token_post( tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[GeneratePasswordResetTokenPostOkResult, GeneralErrorResponse]: + ) -> Union[ + GeneratePasswordResetTokenPostOkResult, + GeneratePasswordResetTokenPostNotAllowedResponse, + GeneralErrorResponse, + ]: pass @abstractmethod @@ -308,7 +353,8 @@ async def password_reset_post( user_context: Dict[str, Any], ) -> Union[ PasswordResetPostOkResult, - PasswordResetPostInvalidTokenResponse, + PasswordResetTokenInvalidError, + PasswordPolicyViolationError, GeneralErrorResponse, ]: pass @@ -318,10 +364,15 @@ async def sign_in_post( self, form_fields: List[FormField], tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignInPostOkResult, SignInPostWrongCredentialsError, GeneralErrorResponse + SignInPostOkResult, + WrongCredentialsError, + SignInPostNotAllowedResponse, + GeneralErrorResponse, ]: pass @@ -330,9 +381,14 @@ async def sign_up_post( self, form_fields: List[FormField], tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ - SignUpPostOkResult, SignUpPostEmailAlreadyExistsError, GeneralErrorResponse + SignUpPostOkResult, + EmailAlreadyExistsError, + SignUpPostNotAllowedResponse, + GeneralErrorResponse, ]: pass diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index f4b681eac..7ec8469e2 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -15,6 +15,7 @@ from os import environ from typing import TYPE_CHECKING, Any, Dict, List, Union +from supertokens_python.auth_utils import is_fake_email from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig @@ -23,12 +24,19 @@ EmailPasswordIngredients, EmailTemplateVars, ) -from supertokens_python.recipe_module import APIHandled, RecipeModule -from ..emailverification.interfaces import ( - UnknownUserIdError, - GetEmailForUserIdOkResult, - EmailDoesNotExistError, +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import ( + FactorIds, + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetEmailsForFactorFromOtherRecipesFunc, + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + GetFactorsSetupForUserFromOtherRecipesFunc, ) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.types import User, RecipeUserId from .api.implementation import APIImplementation from .exceptions import FieldError, SuperTokensEmailPasswordError @@ -43,7 +51,6 @@ from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.querier import Querier -from supertokens_python.recipe.emailverification import EmailVerificationRecipe from .api import ( handle_email_exists_api, @@ -64,7 +71,6 @@ InputOverrideConfig, InputSignUpFeature, validate_and_normalise_user_input, - EmailPasswordConfig, ) @@ -90,11 +96,8 @@ def __init__( email_delivery, ) - def get_emailpassword_config() -> EmailPasswordConfig: - return self.config - recipe_implementation = RecipeImplementation( - Querier.get_instance(recipe_id), get_emailpassword_config + Querier.get_instance(recipe_id), self.config ) self.recipe_implementation = ( recipe_implementation @@ -118,9 +121,145 @@ def get_emailpassword_config() -> EmailPasswordConfig: ) def callback(): - ev_recipe = EmailVerificationRecipe.get_instance_optional() - if ev_recipe: - ev_recipe.add_get_email_for_user_id_func(self.get_email_for_user_id) + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is not None: + + async def f1(_: TenantConfig): + return ["emailpassword"] + + mfa_instance.add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc(f1) + ) + + async def get_factors_setup_for_user( + user: User, _: Dict[str, Any] + ) -> List[str]: + for login_method in user.login_methods: + # We don't check for tenant_id here because if we find the user + # with emailpassword login_method from different tenant, then + # we assume the factor is setup for this user. And as part of factor + # completion, we associate that login_method with the session's tenant_id + if login_method.recipe_id == EmailPasswordRecipe.recipe_id: + return ["emailpassword"] + return [] + + mfa_instance.add_func_to_get_factors_setup_for_user_from_other_recipes( + GetFactorsSetupForUserFromOtherRecipesFunc( + get_factors_setup_for_user + ) + ) + + async def get_emails_for_factor( + user: User, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ]: + # This function is called in the MFA info endpoint API. + # Based on https://github.com/supertokens/supertokens-node/pull/741#discussion_r1432749346 + + # preparing some reusable variables for the logic below... + session_login_method = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ), + None, + ) + if session_login_method is None: + return GetEmailsForFactorUnknownSessionRecipeUserIdResult() + + # We order the login methods based on time_joined (oldest first) + ordered_login_methods = sorted( + user.login_methods, key=lambda lm: lm.time_joined + ) + # Then we take the ones that belong to this recipe + recipe_login_methods = [ + lm + for lm in ordered_login_methods + if lm.recipe_id == EmailPasswordRecipe.recipe_id + ] + + if recipe_login_methods: + # If there are login methods belonging to this recipe, the factor is set up + # In this case we only list email addresses that have a password associated with them + result = ( + # First we take the verified real emails associated with emailpassword login methods ordered by time_joined (oldest first) + [ + lm.email + for lm in recipe_login_methods + if lm.email + and not is_fake_email(lm.email) + and lm.verified + ] + + + # Then we take the non-verified real emails associated with emailpassword login methods ordered by time_joined (oldest first) + [ + lm.email + for lm in recipe_login_methods + if lm.email + and not is_fake_email(lm.email) + and not lm.verified + ] + + + # Lastly, fake emails associated with emailpassword login methods ordered by time_joined (oldest first) + [ + lm.email + for lm in recipe_login_methods + if lm.email and is_fake_email(lm.email) + ] + ) + else: + # This factor hasn't been set up, we list all emails belonging to the user + if any( + lm.email and not is_fake_email(lm.email) + for lm in ordered_login_methods + ): + # If there is at least one real email address linked to the user, we only suggest real addresses + result = [ + lm.email + for lm in ordered_login_methods + if lm.email and not is_fake_email(lm.email) + ] + else: + # Else we use the fake ones + result = [ + lm.email + for lm in ordered_login_methods + if lm.email and is_fake_email(lm.email) + ] + + # Since in this case emails are not guaranteed to be unique, we de-duplicate the results, keeping the oldest one in the list. + result = list(dict.fromkeys(result)) + + # If the login_method associated with the session has an email address, we move it to the top of the list (if it's already in the list) + if ( + session_login_method.email + and session_login_method.email in result + ): + result.remove(session_login_method.email) + result.insert(0, session_login_method.email) + + # If the list is empty we generate an email address to make the flow where the user is never asked for + # an email address easier to implement. + if not result: + result.append( + f"{session_recipe_user_id}@stfakeemail.supertokens.com" + ) + + return GetEmailsForFactorOkResult( + factor_id_to_emails_map={"emailpassword": result} + ) + + mfa_instance.add_func_to_get_emails_for_factor_from_other_recipes( + GetEmailsForFactorFromOtherRecipesFunc(get_emails_for_factor) + ) + + mt_recipe = MultitenancyRecipe.get_instance_optional() + if mt_recipe is not None: + mt_recipe.all_available_first_factors.append(FactorIds.EMAILPASSWORD) PostSTInitCallbacks.add_post_init_callback(callback) @@ -269,16 +408,3 @@ def reset(): ): raise_general_exception("calling testing function in non testing env") EmailPasswordRecipe.__instance = None - - # instance functions below............... - - async def get_email_for_user_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[UnknownUserIdError, GetEmailForUserIdOkResult, EmailDoesNotExistError]: - user_info = await self.recipe_implementation.get_user_by_id( - user_id, user_context - ) - if user_info is not None: - return GetEmailForUserIdOkResult(user_info.email) - - return UnknownUserIdError() diff --git a/supertokens_python/recipe/emailpassword/recipe_implementation.py b/supertokens_python/recipe/emailpassword/recipe_implementation.py index 3e59fd309..283569a5f 100644 --- a/supertokens_python/recipe/emailpassword/recipe_implementation.py +++ b/supertokens_python/recipe/emailpassword/recipe_implementation.py @@ -13,28 +13,36 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Union, Callable +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.asyncio import get_user from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import RecipeUserId from .interfaces import ( CreateResetPasswordOkResult, - CreateResetPasswordWrongUserIdError, + UnknownUserIdError, RecipeInterface, - ResetPasswordUsingTokenOkResult, - ResetPasswordUsingTokenInvalidTokenError, + ConsumePasswordResetTokenOkResult, + PasswordResetTokenInvalidError, SignInOkResult, - SignInWrongCredentialsError, - SignUpEmailAlreadyExistsError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + WrongCredentialsError, + EmailAlreadyExistsError, SignUpOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordUnknownUserIdError, - UpdateEmailOrPasswordPasswordPolicyViolationError, + PasswordPolicyViolationError, ) -from .types import User from .utils import EmailPasswordConfig from .constants import FORM_FIELD_PASSWORD_ID +from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedError, + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, +) +from ...types import User if TYPE_CHECKING: from supertokens_python.querier import Querier @@ -44,122 +52,187 @@ class RecipeImplementation(RecipeInterface): def __init__( self, querier: Querier, - get_emailpassword_config: Callable[[], EmailPasswordConfig], + ep_config: EmailPasswordConfig, ): super().__init__() self.querier = querier - self.get_emailpassword_config = get_emailpassword_config - - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - params = {"userId": user_id} - response = await self.querier.send_get_request( - NormalisedURLPath("/recipe/user"), params, user_context - ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) - return None - - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - params = {"email": email} - response = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), params, user_context + self.ep_config = ep_config + + async def sign_up( + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], + ) -> Union[ + SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError + ]: + response = await self.create_new_recipe_user( + email=email, + password=password, + tenant_id=tenant_id, + user_context=user_context, ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) - return None + if isinstance(response, EmailAlreadyExistsError): + return response - async def create_reset_password_token( - self, user_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[CreateResetPasswordOkResult, CreateResetPasswordWrongUserIdError]: - data = {"userId": user_id} - response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset/token"), - data, + updated_user = response.user + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=response.user, + recipe_user_id=response.recipe_user_id, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, ) - if "status" in response and response["status"] == "OK": - return CreateResetPasswordOkResult(response["token"]) - return CreateResetPasswordWrongUserIdError() - async def reset_password_using_token( + if isinstance(link_result, LinkingToSessionUserFailedError): + return LinkingToSessionUserFailedError(reason=link_result.reason) + + updated_user = link_result.user + + return SignUpOkResult( + user=updated_user, + recipe_user_id=response.recipe_user_id, + ) + + async def create_new_recipe_user( self, - token: str, - new_password: str, + email: str, + password: str, tenant_id: str, user_context: Dict[str, Any], - ) -> Union[ - ResetPasswordUsingTokenOkResult, ResetPasswordUsingTokenInvalidTokenError - ]: - data = {"method": "token", "token": token, "newPassword": new_password} + ) -> Union[SignUpOkResult, EmailAlreadyExistsError]: response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset"), - data, + NormalisedURLPath(f"{tenant_id}/recipe/signup"), + { + "email": email, + "password": password, + }, user_context=user_context, ) - if "status" not in response or response["status"] != "OK": - return ResetPasswordUsingTokenInvalidTokenError() - user_id = None - if "userId" in response: - user_id = response["userId"] - return ResetPasswordUsingTokenOkResult(user_id) + if response["status"] == "OK": + return SignUpOkResult( + user=User.from_json(response["user"]), + recipe_user_id=RecipeUserId(response["recipeUserId"]), + ) + return EmailAlreadyExistsError() async def sign_in( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignInOkResult, SignInWrongCredentialsError]: - data = {"password": password, "email": email} + self, + email: str, + password: str, + tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], + ) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: + response = await self.verify_credentials( + email, password, tenant_id, user_context + ) + + if isinstance(response, SignInOkResult): + login_method = next( + ( + lm + for lm in response.user.login_methods + if lm.recipe_user_id.get_as_string() + == response.recipe_user_id.get_as_string() + ), + None, + ) + + assert login_method is not None + + if not login_method.verified: + await AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=response.user, + recipe_user_id=response.recipe_user_id, + user_context=user_context, + ) + + # We do this to get the updated user (in case the above function updated the verification status) + updated_user = await get_user( + response.recipe_user_id.get_as_string(), user_context + ) + assert updated_user is not None + response.user = updated_user + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=response.user, + recipe_user_id=response.recipe_user_id, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, + user_context=user_context, + ) + + if isinstance(link_result, LinkingToSessionUserFailedError): + return link_result + + response.user = link_result.user + + return response + + async def verify_credentials( + self, + email: str, + password: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[SignInOkResult, WrongCredentialsError]: response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/signin"), - data, + { + "email": email, + "password": password, + }, user_context=user_context, ) - if "status" in response and response["status"] == "OK": + + if response["status"] == "OK": return SignInOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) + user=User.from_json(response["user"]), + recipe_user_id=RecipeUserId(response["recipeUserId"]), ) - return SignInWrongCredentialsError() - async def sign_up( - self, email: str, password: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[SignUpOkResult, SignUpEmailAlreadyExistsError]: - data = {"password": password, "email": email} + return WrongCredentialsError() + + async def create_reset_password_token( + self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + ) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: + data = {"userId": user_id, "email": email} response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signup"), + NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset/token"), data, user_context=user_context, ) if "status" in response and response["status"] == "OK": - return SignUpOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ) - ) - return SignUpEmailAlreadyExistsError() + return CreateResetPasswordOkResult(response["token"]) + return UnknownUserIdError() + + async def consume_password_reset_token( + self, + token: str, + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: + data = {"token": token} + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/user/password/reset/token/consume"), + data, + user_context=user_context, + ) + if "status" not in response or response["status"] != "OK": + return PasswordResetTokenInvalidError() + return ConsumePasswordResetTokenOkResult(response["email"], response["userId"]) async def update_email_or_password( self, - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None], password: Union[str, None], apply_password_policy: Union[bool, None], @@ -167,34 +240,77 @@ async def update_email_or_password( user_context: Dict[str, Any], ) -> Union[ UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, - UpdateEmailOrPasswordUnknownUserIdError, - UpdateEmailOrPasswordPasswordPolicyViolationError, + EmailAlreadyExistsError, + UnknownUserIdError, + PasswordPolicyViolationError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, ]: - data = {"userId": user_id} + account_linking = AccountLinkingRecipe.get_instance() + data = {"recipeUserId": recipe_user_id.get_as_string()} + if email is not None: - data = {"email": email, **data} + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UnknownUserIdError() + + ev_instance = EmailVerificationRecipe.get_instance_optional() + is_email_verified = False + if ev_instance: + is_email_verified = ( + await ev_instance.recipe_implementation.is_email_verified( + recipe_user_id=recipe_user_id, + email=email, + user_context=user_context, + ) + ) + + is_email_change_allowed = await account_linking.is_email_change_allowed( + user=user, + is_verified=is_email_verified, + new_email=email, + session=None, + user_context=user_context, + ) + if not is_email_change_allowed.allowed: + reason = ( + "New email cannot be applied to existing account because of account takeover risks." + if is_email_change_allowed.reason == "ACCOUNT_TAKEOVER_RISK" + else "New email cannot be applied to existing account because of there is another primary user with the same email address." + ) + return UpdateEmailOrPasswordEmailChangeNotAllowedError(reason) + + data["email"] = email + if password is not None: if apply_password_policy is None or apply_password_policy: - form_fields = ( - self.get_emailpassword_config().sign_up_feature.form_fields + form_fields = self.ep_config.sign_up_feature.form_fields + password_field = next( + field for field in form_fields if field.id == FORM_FIELD_PASSWORD_ID ) - password_field = list( - filter(lambda x: x.id == FORM_FIELD_PASSWORD_ID, form_fields) - )[0] error = await password_field.validate( password, tenant_id_for_password_policy ) if error is not None: - return UpdateEmailOrPasswordPasswordPolicyViolationError(error) - data = {"password": password, **data} + return PasswordPolicyViolationError(error) + data["password"] = password + response = await self.querier.send_put_request( NormalisedURLPath("/recipe/user"), data, user_context=user_context, ) - if "status" in response and response["status"] == "OK": + + if response.get("status") == "OK": + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UnknownUserIdError() + await AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) return UpdateEmailOrPasswordOkResult() - if "status" in response and response["status"] == "EMAIL_ALREADY_EXISTS_ERROR": - return UpdateEmailOrPasswordEmailAlreadyExistsError() - return UpdateEmailOrPasswordUnknownUserIdError() + elif response.get("status") == "EMAIL_ALREADY_EXISTS_ERROR": + return EmailAlreadyExistsError() + else: + return UnknownUserIdError() diff --git a/supertokens_python/recipe/emailpassword/syncio/__init__.py b/supertokens_python/recipe/emailpassword/syncio/__init__.py index 8a7cd833f..1681589cb 100644 --- a/supertokens_python/recipe/emailpassword/syncio/__init__.py +++ b/supertokens_python/recipe/emailpassword/syncio/__init__.py @@ -12,130 +12,200 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict, Union, Optional +from typing_extensions import Literal from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.emailpassword.interfaces import ( + SignUpOkResult, + EmailAlreadyExistsError, + LinkingToSessionUserFailedError, + SignInOkResult, + WrongCredentialsError, + CreateResetPasswordOkResult, + ConsumePasswordResetTokenOkResult, + PasswordResetTokenInvalidError, + UpdateEmailOrPasswordOkResult, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, +) +from supertokens_python.recipe.emailpassword.types import ( + EmailTemplateVars, +) +from supertokens_python.types import RecipeUserId -from ..interfaces import SignInOkResult, SignInWrongCredentialsError -from ..types import EmailTemplateVars, User +def sign_up( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignUpOkResult, EmailAlreadyExistsError, LinkingToSessionUserFailedError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import sign_up as async_sign_up -def update_email_or_password( - user_id: str, - email: Union[str, None] = None, - password: Union[str, None] = None, - apply_password_policy: Union[bool, None] = None, - tenant_id_for_password_policy: Optional[str] = None, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password - - return sync( - update_email_or_password( - user_id, - email, - password, - apply_password_policy, - tenant_id_for_password_policy, - user_context, - ) - ) + return sync(async_sign_up(tenant_id, email, password, session, user_context)) -def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[None, User]: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id +def sign_in( + tenant_id: str, + email: str, + password: str, + session: Optional[SessionContainer] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError, LinkingToSessionUserFailedError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import sign_in as async_sign_in - return sync(get_user_by_id(user_id, user_context)) + return sync(async_sign_in(tenant_id, email, password, session, user_context)) -def get_user_by_email( +def verify_credentials( tenant_id: str, email: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> Union[None, User]: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email + password: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[SignInOkResult, WrongCredentialsError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + verify_credentials as async_verify_credentials, + ) - return sync(get_user_by_email(tenant_id, email, user_context)) + return sync(async_verify_credentials(tenant_id, email, password, user_context)) def create_reset_password_token( tenant_id: str, user_id: str, - user_context: Union[None, Dict[str, Any]] = None, -): + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateResetPasswordOkResult, UnknownUserIdError]: + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_token, + create_reset_password_token as async_create_reset_password_token, ) - return sync(create_reset_password_token(tenant_id, user_id, user_context)) + return sync( + async_create_reset_password_token(tenant_id, user_id, email, user_context) + ) def reset_password_using_token( tenant_id: str, token: str, new_password: str, - user_context: Union[None, Dict[str, Any]] = None, -): + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + PasswordPolicyViolationError, + PasswordResetTokenInvalidError, + UnknownUserIdError, +]: + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - reset_password_using_token, + reset_password_using_token as async_reset_password_using_token, ) return sync( - reset_password_using_token(tenant_id, token, new_password, user_context) + async_reset_password_using_token(tenant_id, token, new_password, user_context) ) -def sign_in( - tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> Union[SignInOkResult, SignInWrongCredentialsError]: - from supertokens_python.recipe.emailpassword.asyncio import sign_in - - return sync(sign_in(tenant_id, email, password, user_context)) - - -def sign_up( +def consume_password_reset_token( tenant_id: str, - email: str, - password: str, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.emailpassword.asyncio import sign_up + token: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ConsumePasswordResetTokenOkResult, PasswordResetTokenInvalidError]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + consume_password_reset_token as async_consume_password_reset_token, + ) - return sync(sign_up(tenant_id, email, password, user_context)) + return sync(async_consume_password_reset_token(tenant_id, token, user_context)) -def send_email( - input_: EmailTemplateVars, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.emailpassword.asyncio import send_email +def update_email_or_password( + recipe_user_id: RecipeUserId, + email: Optional[str] = None, + password: Optional[str] = None, + apply_password_policy: Optional[bool] = None, + tenant_id_for_password_policy: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + UpdateEmailOrPasswordOkResult, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + PasswordPolicyViolationError, +]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + update_email_or_password as async_update_email_or_password, + ) - return sync(send_email(input_, user_context)) + return sync( + async_update_email_or_password( + recipe_user_id, + email, + password, + apply_password_policy, + tenant_id_for_password_policy, + user_context, + ) + ) def create_reset_password_link( tenant_id: str, user_id: str, + email: str, user_context: Optional[Dict[str, Any]] = None, -): +) -> Union[str, UnknownUserIdError]: + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - create_reset_password_link, + create_reset_password_link as async_create_reset_password_link, ) - return sync(create_reset_password_link(tenant_id, user_id, user_context)) + return sync( + async_create_reset_password_link(tenant_id, user_id, email, user_context) + ) def send_reset_password_email( tenant_id: str, user_id: str, + email: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[Literal["UNKNOWN_USER_ID_ERROR"], Literal["OK"]]: + if user_context is None: + user_context = {} + from supertokens_python.recipe.emailpassword.asyncio import ( + send_reset_password_email as async_send_reset_password_email, + ) + + return sync( + async_send_reset_password_email(tenant_id, user_id, email, user_context) + ) + + +def send_email( + input_: EmailTemplateVars, user_context: Optional[Dict[str, Any]] = None, ): + if user_context is None: + user_context = {} from supertokens_python.recipe.emailpassword.asyncio import ( - send_reset_password_email, + send_email as async_send_email, ) - return sync(send_reset_password_email(tenant_id, user_id, user_context)) + return sync(async_send_email(input_, user_context)) diff --git a/supertokens_python/recipe/emailpassword/types.py b/supertokens_python/recipe/emailpassword/types.py index ef3d4b6f9..890ec14b4 100644 --- a/supertokens_python/recipe/emailpassword/types.py +++ b/supertokens_python/recipe/emailpassword/types.py @@ -12,39 +12,14 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations - -from typing import Any, Awaitable, Callable, List, TypeVar, Union +from typing import Awaitable, Callable, Dict, Optional, TypeVar, Union, Any from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( EmailDeliveryInterface, SMTPServiceInterface, ) - - -class User: - def __init__( - self, user_id: str, email: str, time_joined: int, tenant_ids: List[str] - ): - self.user_id = user_id - self.email = email - self.time_joined = time_joined - self.tenant_ids = tenant_ids - - def __eq__(self, other: object): - return ( - isinstance(other, self.__class__) - and self.user_id == other.user_id - and self.email == other.email - and self.time_joined == other.time_joined - and self.tenant_ids == other.tenant_ids - ) - - -class UsersResponse: - def __init__(self, users: List[User], next_pagination_token: Union[str, None]): - self.users = users - self.next_pagination_token = next_pagination_token +from supertokens_python.types import RecipeUserId class ErrorFormField: @@ -58,6 +33,9 @@ def __init__(self, id: str, value: Any): # pylint: disable=redefined-builtin self.id: str = id self.value: Any = value + def to_json(self) -> Dict[str, Any]: + return {"id": self.id, "value": self.value} + class InputFormField: def __init__( @@ -90,10 +68,26 @@ def __init__( class PasswordResetEmailTemplateVarsUser: - def __init__(self, user_id: str, email: str): + def __init__( + self, user_id: str, recipe_user_id: Optional[RecipeUserId], email: str + ): self.id = user_id + self.recipe_user_id = recipe_user_id self.email = email + def to_json(self) -> Dict[str, Any]: + resp_json = { + "id": self.id, + "recipeUserId": ( + self.recipe_user_id.get_as_string() + if self.recipe_user_id is not None + else None + ), + "email": self.email, + } + # Remove items that are None + return {k: v for k, v in resp_json.items() if v is not None} + class PasswordResetEmailTemplateVars: def __init__( @@ -106,6 +100,14 @@ def __init__( self.password_reset_link = password_reset_link self.tenant_id = tenant_id + def to_json(self) -> Dict[str, Any]: + return { + "type": "PASSWORD_RESET", + "user": self.user.to_json(), + "passwordResetLink": self.password_reset_link, + "tenantId": self.tenant_id, + } + # Export: EmailTemplateVars = PasswordResetEmailTemplateVars diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index 78d2134cc..ba1f67fc3 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -43,12 +43,12 @@ def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ) -> Callable[[AppInfo], RecipeModule]: return EmailVerificationRecipe.init( mode, email_delivery, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) diff --git a/supertokens_python/recipe/emailverification/asyncio/__init__.py b/supertokens_python/recipe/emailverification/asyncio/__init__.py index c15dc121d..c2f199067 100644 --- a/supertokens_python/recipe/emailverification/asyncio/__init__.py +++ b/supertokens_python/recipe/emailverification/asyncio/__init__.py @@ -28,7 +28,7 @@ ) from supertokens_python.recipe.emailverification.types import EmailTemplateVars from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe - +from supertokens_python.types import RecipeUserId from supertokens_python.recipe.emailverification.utils import get_email_verify_link from supertokens_python.recipe.emailverification.types import ( VerificationEmailTemplateVars, @@ -38,7 +38,7 @@ async def create_email_verification_token( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ @@ -47,9 +47,11 @@ async def create_email_verification_token( ]: if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -58,31 +60,36 @@ async def create_email_verification_token( raise Exception("Unknown User ID provided without email") return await recipe.recipe_implementation.create_email_verification_token( - user_id, email, tenant_id, user_context + recipe_user_id, email, tenant_id, user_context ) async def verify_email_using_token( - tenant_id: str, token: str, user_context: Union[None, Dict[str, Any]] = None + tenant_id: str, + token: str, + attempt_account_linking: bool = True, + user_context: Union[None, Dict[str, Any]] = None, ): if user_context is None: user_context = {} - return await EmailVerificationRecipe.get_instance().recipe_implementation.verify_email_using_token( - token, tenant_id, user_context + return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.verify_email_using_token( + token, tenant_id, attempt_account_linking, user_context ) async def is_email_verified( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -91,22 +98,24 @@ async def is_email_verified( raise Exception("Unknown User ID provided without email") return await recipe.recipe_implementation.is_email_verified( - user_id, email, user_context + recipe_user_id, email, user_context ) async def revoke_email_verification_tokens( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, ) -> RevokeEmailVerificationTokensOkResult: if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -114,22 +123,24 @@ async def revoke_email_verification_tokens( else: raise Exception("Unknown User ID provided without email") - return await EmailVerificationRecipe.get_instance().recipe_implementation.revoke_email_verification_tokens( - user_id, email, tenant_id, user_context + return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.revoke_email_verification_tokens( + recipe_user_id, email, tenant_id, user_context ) async def unverify_email( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): if user_context is None: user_context = {} - recipe = EmailVerificationRecipe.get_instance() + recipe = EmailVerificationRecipe.get_instance_or_throw() if email is None: - email_info = await recipe.get_email_for_user_id(user_id, user_context) + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -139,8 +150,8 @@ async def unverify_email( else: raise Exception("Unknown User ID provided without email") - return await EmailVerificationRecipe.get_instance().recipe_implementation.unverify_email( - user_id, email, user_context + return await EmailVerificationRecipe.get_instance_or_throw().recipe_implementation.unverify_email( + recipe_user_id, email, user_context ) @@ -150,14 +161,14 @@ async def send_email( ): if user_context is None: user_context = {} - return await EmailVerificationRecipe.get_instance().email_delivery.ingredient_interface_impl.send_email( + return await EmailVerificationRecipe.get_instance_or_throw().email_delivery.ingredient_interface_impl.send_email( input_, user_context ) async def create_email_verification_link( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ) -> Union[ @@ -167,11 +178,11 @@ async def create_email_verification_link( if user_context is None: user_context = {} - recipe_instance = EmailVerificationRecipe.get_instance() + recipe_instance = EmailVerificationRecipe.get_instance_or_throw() app_info = recipe_instance.get_app_info() email_verification_token = await create_email_verification_token( - tenant_id, user_id, email, user_context + tenant_id, recipe_user_id, email, user_context ) if isinstance( email_verification_token, CreateEmailVerificationTokenEmailAlreadyVerifiedError @@ -193,6 +204,7 @@ async def create_email_verification_link( async def send_email_verification_email( tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ) -> Union[ @@ -203,9 +215,11 @@ async def send_email_verification_email( user_context = {} if email is None: - recipe_instance = EmailVerificationRecipe.get_instance() + recipe_instance = EmailVerificationRecipe.get_instance_or_throw() - email_info = await recipe_instance.get_email_for_user_id(user_id, user_context) + email_info = await recipe_instance.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): @@ -214,7 +228,7 @@ async def send_email_verification_email( raise Exception("Unknown User ID provided without email") email_verification_link = await create_email_verification_link( - tenant_id, user_id, email, user_context + tenant_id, recipe_user_id, email, user_context ) if isinstance( @@ -224,7 +238,7 @@ async def send_email_verification_email( await send_email( VerificationEmailTemplateVars( - user=VerificationEmailTemplateVarsUser(user_id, email), + user=VerificationEmailTemplateVarsUser(user_id, recipe_user_id, email), email_verify_link=email_verification_link.link, tenant_id=tenant_id, ), diff --git a/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py b/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py index 128cef084..20d82c4ce 100644 --- a/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py +++ b/supertokens_python/recipe/emailverification/emaildelivery/services/backward_compatibility/__init__.py @@ -19,7 +19,7 @@ from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryInterface from supertokens_python.logger import log_debug_message from supertokens_python.recipe.emailverification.types import ( - User, + EmailVerificationUser, VerificationEmailTemplateVars, ) from supertokens_python.supertokens import AppInfo @@ -27,7 +27,7 @@ async def create_and_send_email_using_supertokens_service( - app_info: AppInfo, user: User, email_verification_url: str + app_info: AppInfo, user: EmailVerificationUser, email_verification_url: str ) -> None: if ("SUPERTOKENS_ENV" in environ) and (environ["SUPERTOKENS_ENV"] == "testing"): return @@ -62,7 +62,9 @@ async def send_email( user_context: Dict[str, Any], ) -> None: try: - email_user = User(template_vars.user.id, template_vars.user.email) + email_user = EmailVerificationUser( + template_vars.user.recipe_user_id, template_vars.user.email + ) await create_and_send_email_using_supertokens_service( self.app_info, email_user, template_vars.email_verify_link ) diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index bfa0b3e09..b6d07c889 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -17,54 +17,58 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient -from supertokens_python.types import APIResponse, GeneralErrorResponse +from supertokens_python.types import APIResponse, GeneralErrorResponse, RecipeUserId from ...supertokens import AppInfo from ..session.interfaces import SessionContainer +from typing_extensions import Literal if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse - from .types import User, VerificationEmailTemplateVars + from .types import EmailVerificationUser, VerificationEmailTemplateVars from .utils import EmailVerificationConfig class CreateEmailVerificationTokenOkResult: - status = "OK" + status: Literal["OK"] = "OK" def __init__(self, token: str): self.token = token class CreateEmailVerificationTokenEmailAlreadyVerifiedError: - status = "EMAIL_ALREADY_VERIFIED_ERROR" + status: Literal["EMAIL_ALREADY_VERIFIED_ERROR"] = "EMAIL_ALREADY_VERIFIED_ERROR" class CreateEmailVerificationLinkEmailAlreadyVerifiedError: - status = "EMAIL_ALREADY_VERIFIED_ERROR" + status: Literal["EMAIL_ALREADY_VERIFIED_ERROR"] = "EMAIL_ALREADY_VERIFIED_ERROR" class CreateEmailVerificationLinkOkResult: - status = "OK" + status: Literal["OK"] = "OK" def __init__(self, link: str): self.link = link class SendEmailVerificationEmailAlreadyVerifiedError: - status = "EMAIL_ALREADY_VERIFIED_ERROR" + status: Literal["EMAIL_ALREADY_VERIFIED_ERROR"] = "EMAIL_ALREADY_VERIFIED_ERROR" class SendEmailVerificationEmailOkResult: - status = "OK" + status: Literal["OK"] = "OK" class VerifyEmailUsingTokenOkResult: - status = "OK" + status: Literal["OK"] = "OK" - def __init__(self, user: User): + def __init__(self, user: EmailVerificationUser): self.user = user + def to_json(self) -> Dict[str, Any]: + return {"user": self.user.to_json(), "status": self.status} + class VerifyEmailUsingTokenInvalidTokenError: pass @@ -84,7 +88,11 @@ def __init__(self): @abstractmethod async def create_email_verification_token( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> Union[ CreateEmailVerificationTokenOkResult, CreateEmailVerificationTokenEmailAlreadyVerifiedError, @@ -93,25 +101,33 @@ async def create_email_verification_token( @abstractmethod async def verify_email_using_token( - self, token: str, tenant_id: str, user_context: Dict[str, Any] + self, + token: str, + tenant_id: str, + attempt_account_linking: bool, + user_context: Dict[str, Any], ) -> Union[VerifyEmailUsingTokenOkResult, VerifyEmailUsingTokenInvalidTokenError]: pass @abstractmethod async def is_email_verified( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> bool: pass @abstractmethod async def revoke_email_verification_tokens( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> RevokeEmailVerificationTokensOkResult: pass @abstractmethod async def unverify_email( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> UnverifyEmailOkResult: pass @@ -137,15 +153,15 @@ def __init__( class EmailVerifyPostOkResult(APIResponse): - def __init__(self, user: User): + def __init__( + self, user: EmailVerificationUser, new_session: Optional[SessionContainer] + ): self.user = user + self.new_session = new_session self.status = "OK" def to_json(self) -> Dict[str, Any]: - return { - "status": self.status, - "user": {"id": self.user.user_id, "email": self.user.email}, - } + return {"status": self.status} class EmailVerifyPostInvalidTokenError(APIResponse): @@ -157,9 +173,10 @@ def to_json(self) -> Dict[str, Any]: class IsEmailVerifiedGetOkResult(APIResponse): - def __init__(self, is_verified: bool): + def __init__(self, is_verified: bool, new_session: Optional[SessionContainer]): self.status = "OK" self.is_verified = is_verified + self.new_session = new_session def to_json(self) -> Dict[str, Any]: return {"status": self.status, "isVerified": self.is_verified} @@ -174,8 +191,9 @@ def to_json(self) -> Dict[str, Any]: class GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError(APIResponse): - def __init__(self): + def __init__(self, new_session: Optional[SessionContainer]): self.status = "EMAIL_ALREADY_VERIFIED_ERROR" + self.new_session = new_session def to_json(self) -> Dict[str, Any]: return {"status": self.status} @@ -237,7 +255,7 @@ class UnknownUserIdError(Exception): TypeGetEmailForUserIdFunction = Callable[ - [str, Dict[str, Any]], + [RecipeUserId, Dict[str, Any]], Awaitable[ Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError] ], diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index e02dd06ef..0fc977b5b 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -32,9 +32,9 @@ from ...ingredients.emaildelivery.types import EmailDeliveryConfig from ...logger import log_debug_message from ...post_init_callbacks import PostSTInitCallbacks -from ...types import MaybeAwaitable from ...utils import get_timestamp_ms from ..session import SessionRecipe +from ..session.asyncio import revoke_all_sessions_for_user, create_new_session from ..session.claim_base_classes.boolean_claim import ( BooleanClaim, BooleanClaimValidators, @@ -67,6 +67,8 @@ from supertokens_python.framework.request import BaseRequest from supertokens_python.framework.response import BaseResponse from supertokens_python.supertokens import AppInfo + from supertokens_python.types import RecipeUserId + from ...types import User, MaybeAwaitable from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.querier import Querier @@ -90,7 +92,7 @@ def __init__( ingredients: EmailVerificationIngredients, mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ) -> None: super().__init__(recipe_id, app_info) @@ -98,12 +100,13 @@ def __init__( app_info, mode, email_delivery, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) recipe_implementation = RecipeImplementation( - Querier.get_instance(recipe_id), self.config + Querier.get_instance(recipe_id), + self.get_email_for_recipe_user_id, ) self.recipe_implementation = ( recipe_implementation @@ -126,10 +129,6 @@ def __init__( else: self.email_delivery = email_delivery_ingredient - self.get_email_for_user_id_funcs_from_other_recipes: List[ - TypeGetEmailForUserIdFunction - ] = [] - def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and isinstance( err, SuperTokensEmailVerificationError @@ -206,7 +205,7 @@ def get_all_cors_headers(self) -> List[str]: def init( mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ): def func(app_info: AppInfo) -> EmailVerificationRecipe: @@ -218,7 +217,7 @@ def func(app_info: AppInfo) -> EmailVerificationRecipe: ingredients, mode, email_delivery, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) @@ -231,6 +230,15 @@ def callback(): EmailVerificationClaim.validators.is_verified() ) + from supertokens_python.recipe.accountlinking.recipe import ( + AccountLinkingRecipe, + ) + + assert EmailVerificationRecipe.__instance is not None + AccountLinkingRecipe.get_instance().register_email_verification_recipe( + EmailVerificationRecipe.__instance + ) + PostSTInitCallbacks.add_post_init_callback(callback) return EmailVerificationRecipe.__instance @@ -241,7 +249,7 @@ def callback(): return func @staticmethod - def get_instance() -> EmailVerificationRecipe: + def get_instance_or_throw() -> EmailVerificationRecipe: if EmailVerificationRecipe.__instance is not None: return EmailVerificationRecipe.__instance raise_general_exception( @@ -260,23 +268,186 @@ def reset(): raise_general_exception("calling testing function in non testing env") EmailVerificationRecipe.__instance = None - async def get_email_for_user_id( - self, user_id: str, user_context: Dict[str, Any] + async def get_email_for_recipe_user_id( + self, + user: Optional[User], + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError]: - if self.config.get_email_for_user_id is not None: - res = await self.config.get_email_for_user_id(user_id, user_context) - if not isinstance(res, UnknownUserIdError): - return res + if self.config.get_email_for_recipe_user_id is not None: + user_res = await self.config.get_email_for_recipe_user_id( + recipe_user_id, user_context + ) + if not isinstance(user_res, UnknownUserIdError): + return user_res + + if user is None: + from supertokens_python.recipe.accountlinking.recipe import ( + AccountLinkingRecipe, + ) + + user = await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + recipe_user_id.get_as_string(), user_context + ) - for f in self.get_email_for_user_id_funcs_from_other_recipes: - res = await f(user_id, user_context) - if not isinstance(res, UnknownUserIdError): - return res + if user is None: + return UnknownUserIdError() + + for login_method in user.login_methods: + if ( + login_method.recipe_user_id.get_as_string() + == recipe_user_id.get_as_string() + ): + if login_method.email is not None: + return GetEmailForUserIdOkResult(email=login_method.email) + else: + return EmailDoesNotExistError() return UnknownUserIdError() - def add_get_email_for_user_id_func(self, f: TypeGetEmailForUserIdFunction): - self.get_email_for_user_id_funcs_from_other_recipes.append(f) + async def get_primary_user_id_for_recipe_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> str: + # We extract this into its own function like this cause we want to make sure that + # this recipe does not get the email of the user ID from the getUser function. + # In fact, there is a test "email verification recipe uses getUser function only in getEmailForRecipeUserId" + # which makes sure that this function is only called in 3 places in this recipe: + # - this function + # - getEmailForRecipeUserId function (above) + # - after verification to get the updated user in verifyEmailUsingToken + # We want to isolate the result of calling this function as much as possible + # so that the consumer of the getUser function does not read the email + # from the primaryUser. Hence, this function only returns the string ID + # and nothing else from the primaryUser. + from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe + + primary_user = ( + await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + recipe_user_id.get_as_string(), user_context + ) + ) + if primary_user is None: + # This can come here if the user is using session + email verification + # recipe with a user ID that is not known to supertokens. In this case, + # we do not allow linking for such users. + return recipe_user_id.get_as_string() + return primary_user.id + + async def update_session_if_required_post_email_verification( + self, + req: BaseRequest, + session: Optional[SessionContainer], + recipe_user_id_whose_email_got_verified: RecipeUserId, + user_context: Dict[str, Any], + ) -> Optional[SessionContainer]: + primary_user_id = await self.get_primary_user_id_for_recipe_user( + recipe_user_id_whose_email_got_verified, user_context + ) + + # if a session exists in the API, then we can update the session + # claim related to email verification + if session is not None: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification got session" + ) + # Due to linking, we will have to correct the current + # session's user ID. There are four cases here: + # --> (Case 1) User signed up and did email verification and the new account + # became a primary user (user ID no change) + # --> (Case 2) User signed up and did email verification and the new account got linked + # to another primary user (user ID change) + # --> (Case 3) This is post login account linking, in which the account that got verified + # got linked to the session's account (user ID of account has changed to the session's user ID) + # --> (Case 4) This is post login account linking, in which the account that got verified + # got linked to ANOTHER primary account (user ID of account has changed to a different user ID != session.getUserId, but + # we should ignore this since it will result in the user's session changing.) + + if ( + session.get_recipe_user_id(user_context).get_as_string() + == recipe_user_id_whose_email_got_verified.get_as_string() + ): + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the session belongs to the verified user" + ) + # this means that the session's login method's account is the + # one that just got verified and that we are NOT doing post login + # account linking. So this is only for (Case 1) and (Case 2) + + if session.get_user_id() == primary_user_id: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the session userId matches the primary user id, so we are only refreshing the claim" + ) + # if the session's primary user ID is equal to the + # primary user ID that the account was linked to, then + # this means that the new account became a primary user (Case 1) + # We also have the sub cases here that the account that just + # got verified was already linked to the session's primary user ID, + # but either way, we don't need to change any user ID. + + # In this case, all we do is to update the emailverification claim + try: + # EmailVerificationClaim will be based on the recipeUserId + # and not the primary user ID. + await session.fetch_and_set_claim( + EmailVerificationClaim, user_context + ) + except Exception as err: + # This should never happen, since we've just set the status above. + if str(err) == "UNKNOWN_USER_ID": + raise_unauthorised_exception("Unknown User ID provided") + raise err + + return None + else: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the session user id doesn't match the primary user id, so we are revoking all sessions and creating a new one" + ) + # if the session's primary user ID is NOT equal to the + # primary user ID that the account that it was linked to, then + # this means that the new account got linked to another primary user (Case 2) + + # In this case, we need to update the session's user ID by creating + # a new session + + # Revoke all session belonging to session.getRecipeUserId() + # We do not really need to do this, but we do it anyway.. no harm. + await revoke_all_sessions_for_user( + recipe_user_id_whose_email_got_verified.get_as_string(), + False, + None, + user_context, + ) + + # create a new session and return that.. + return await create_new_session( + req, + session.get_tenant_id(), + session.get_recipe_user_id(user_context), + {}, + {}, + user_context, + ) + else: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification the verified user doesn't match the session" + ) + # this means that the session's login method's account was NOT the + # one that just got verified and that we ARE doing post login + # account linking. So this is only for (Case 3) and (Case 4) + + # In both case 3 and case 4, we do not want to change anything in the + # current session in terms of user ID or email verification claim (since + # both of these refer to the current logged in user and not the newly + # linked user's account). + + return None + else: + log_debug_message( + "updateSessionIfRequiredPostEmailVerification got no session" + ) + # the session is updated when the is email verification GET API is called + # so we don't do anything in this API. + return None class EmailVerificationClaimValidators(BooleanClaimValidators): @@ -303,14 +474,20 @@ def is_verified( class EmailVerificationClaimClass(BooleanClaim): def __init__(self): async def fetch_value( - user_id: str, _tenant_id: str, user_context: Dict[str, Any] + _: str, + recipe_user_id: RecipeUserId, + __: str, + ___: Dict[str, Any], + user_context: Dict[str, Any], ) -> bool: - recipe = EmailVerificationRecipe.get_instance() - email_info = await recipe.get_email_for_user_id(user_id, user_context) + recipe = EmailVerificationRecipe.get_instance_or_throw() + email_info = await recipe.get_email_for_recipe_user_id( + None, recipe_user_id, user_context + ) if isinstance(email_info, GetEmailForUserIdOkResult): return await recipe.recipe_implementation.is_email_verified( - user_id, email_info.email, user_context + recipe_user_id, email_info.email, user_context ) if isinstance(email_info, EmailDoesNotExistError): # we consider people without email addresses as validated @@ -336,25 +513,15 @@ async def email_verify_post( ) -> Union[EmailVerifyPostOkResult, EmailVerifyPostInvalidTokenError]: response = await api_options.recipe_implementation.verify_email_using_token( - token, tenant_id, user_context + token, tenant_id, True, user_context ) if isinstance(response, VerifyEmailUsingTokenOkResult): - if session is not None: - try: - await session.fetch_and_set_claim( - EmailVerificationClaim, user_context - ) - except Exception as e: - # This should never happen since we have just set the status above - if str(e) == "UNKNOWN_USER_ID": - log_debug_message( - "verifyEmailPOST: Returning UNAUTHORISED because the user id provided is unknown" - ) - raise_unauthorised_exception("Unknown User ID provided") - else: - raise e + email_verification_recipe = EmailVerificationRecipe.get_instance_or_throw() + new_session = await email_verification_recipe.update_session_if_required_post_email_verification( + api_options.request, session, response.user.recipe_user_id, user_context + ) - return EmailVerifyPostOkResult(response.user) + return EmailVerifyPostOkResult(response.user, new_session) return EmailVerifyPostInvalidTokenError() async def is_email_verified_get( @@ -363,27 +530,40 @@ async def is_email_verified_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> IsEmailVerifiedGetOkResult: - try: - await session.fetch_and_set_claim(EmailVerificationClaim, user_context) - except Exception as e: - if str(e) == "UNKNOWN_USER_ID": - log_debug_message( - "isEmailVerifiedGET: Returning UNAUTHORISED because the user id provided is unknown" - ) - raise_unauthorised_exception("Unknown User ID provided") - else: - raise e - - is_verified = await session.get_claim_value( - EmailVerificationClaim, user_context + recipe = EmailVerificationRecipe.get_instance_or_throw() + email_info = await recipe.get_email_for_recipe_user_id( + None, session.get_recipe_user_id(user_context), user_context ) - if is_verified is None: - raise Exception( - "Should never come here: EmailVerificationClaim failed to set value" + if isinstance(email_info, GetEmailForUserIdOkResult): + is_verified = await api_options.recipe_implementation.is_email_verified( + session.get_recipe_user_id(user_context), email_info.email, user_context ) - return IsEmailVerifiedGetOkResult(is_verified) + if is_verified: + new_session = ( + await recipe.update_session_if_required_post_email_verification( + api_options.request, + session, + session.get_recipe_user_id(user_context), + user_context, + ) + ) + return IsEmailVerifiedGetOkResult(True, new_session) + else: + await session.set_claim_value( + EmailVerificationClaim, False, user_context + ) + return IsEmailVerifiedGetOkResult(False, None) + elif isinstance(email_info, EmailDoesNotExistError): + # We consider people without email addresses as validated + return IsEmailVerifiedGetOkResult(True, None) + else: + # This means that the user ID is not known to supertokens. This could + # happen if the current session's user ID is not an auth user, + # or if it belongs to a recipe user ID that got deleted. Either way, + # we logout the user. + raise_unauthorised_exception("Unknown User ID provided") async def generate_email_verify_token_post( self, @@ -394,22 +574,30 @@ async def generate_email_verify_token_post( GenerateEmailVerifyTokenPostOkResult, GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError, ]: - user_id = session.get_user_id(user_context) - email_info = await EmailVerificationRecipe.get_instance().get_email_for_user_id( - user_id, user_context - ) tenant_id = session.get_tenant_id() + email_info = await EmailVerificationRecipe.get_instance_or_throw().get_email_for_recipe_user_id( + None, session.get_recipe_user_id(user_context), user_context + ) + if isinstance(email_info, EmailDoesNotExistError): log_debug_message( "Email verification email not sent to user %s because it doesn't have an email address.", - user_id, + session.get_recipe_user_id(user_context).get_as_string(), ) - return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError() - if isinstance(email_info, GetEmailForUserIdOkResult): + # This can happen if the user ID was found, but it has no email. In this + # case, we treat it as a success case. + new_session = await EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + api_options.request, + session, + session.get_recipe_user_id(user_context), + user_context, + ) + return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError(new_session) + elif isinstance(email_info, GetEmailForUserIdOkResult): response = ( await api_options.recipe_implementation.create_email_verification_token( - user_id, + session.get_recipe_user_id(user_context), email_info.email, tenant_id, user_context, @@ -419,21 +607,25 @@ async def generate_email_verify_token_post( if isinstance( response, CreateEmailVerificationTokenEmailAlreadyVerifiedError ): - if await session.get_claim_value(EmailVerificationClaim) is not True: - # this can happen if the email was "verified" in another browser - # and this session is still outdated - and the user has not - # called the get email verification API yet. - await session.fetch_and_set_claim( - EmailVerificationClaim, user_context - ) log_debug_message( - "Email verification email not sent to %s because it is already verified.", - email_info.email, + "Email verification email not sent to user %s because it is already verified.", + session.get_recipe_user_id(user_context).get_as_string(), + ) + new_session = await EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + api_options.request, + session, + session.get_recipe_user_id(user_context), + user_context, + ) + return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError( + new_session ) - return GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError() - if await session.get_claim_value(EmailVerificationClaim) is not False: - # this can happen if the email was "unverified" in another browser + if ( + await session.get_claim_value(EmailVerificationClaim, user_context) + is not False + ): + # This can happen if the email was unverified in another browser # and this session is still outdated - and the user has not # called the get email verification API yet. await session.fetch_and_set_claim(EmailVerificationClaim, user_context) @@ -446,9 +638,15 @@ async def generate_email_verify_token_post( user_context, ) - log_debug_message("Sending email verification email to %s", email_info) + log_debug_message( + "Sending email verification email to %s", email_info.email + ) email_verification_email_delivery_input = VerificationEmailTemplateVars( - user=VerificationEmailTemplateVarsUser(user_id, email_info.email), + user=VerificationEmailTemplateVarsUser( + _id=session.get_user_id(user_context), + recipe_user_id=session.get_recipe_user_id(user_context), + email=email_info.email, + ), email_verify_link=email_verify_link, tenant_id=tenant_id, ) diff --git a/supertokens_python/recipe/emailverification/recipe_implementation.py b/supertokens_python/recipe/emailverification/recipe_implementation.py index c4b17f237..ea5ab3fb3 100644 --- a/supertokens_python/recipe/emailverification/recipe_implementation.py +++ b/supertokens_python/recipe/emailverification/recipe_implementation.py @@ -13,7 +13,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Union +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Union, Optional, Callable from supertokens_python.normalised_url_path import NormalisedURLPath @@ -25,28 +25,48 @@ UnverifyEmailOkResult, VerifyEmailUsingTokenInvalidTokenError, VerifyEmailUsingTokenOkResult, + GetEmailForUserIdOkResult, + EmailDoesNotExistError, + UnknownUserIdError, ) -from .types import User +from .types import EmailVerificationUser +from supertokens_python.asyncio import get_user +from supertokens_python.types import RecipeUserId, User if TYPE_CHECKING: from supertokens_python.querier import Querier - from .utils import EmailVerificationConfig - class RecipeImplementation(RecipeInterface): - def __init__(self, querier: Querier, config: EmailVerificationConfig): + def __init__( + self, + querier: Querier, + get_email_for_recipe_user_id: Callable[ + [Optional[User], RecipeUserId, Dict[str, Any]], + Awaitable[ + Union[ + GetEmailForUserIdOkResult, + EmailDoesNotExistError, + UnknownUserIdError, + ] + ], + ], + ): super().__init__() self.querier = querier - self.config = config + self.get_email_for_recipe_user_id = get_email_for_recipe_user_id async def create_email_verification_token( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> Union[ CreateEmailVerificationTokenOkResult, CreateEmailVerificationTokenEmailAlreadyVerifiedError, ]: - data = {"userId": user_id, "email": email} + data = {"userId": recipe_user_id.get_as_string(), "email": email} response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/user/email/verify/token"), data, @@ -57,7 +77,11 @@ async def create_email_verification_token( return CreateEmailVerificationTokenEmailAlreadyVerifiedError() async def verify_email_using_token( - self, token: str, tenant_id: str, user_context: Dict[str, Any] + self, + token: str, + tenant_id: str, + attempt_account_linking: bool, + user_context: Dict[str, Any], ) -> Union[VerifyEmailUsingTokenOkResult, VerifyEmailUsingTokenInvalidTokenError]: data = {"method": "token", "token": token} response = await self.querier.send_post_request( @@ -65,25 +89,55 @@ async def verify_email_using_token( data, user_context, ) - if "status" in response and response["status"] == "OK": + if response["status"] == "OK": + recipe_user_id = RecipeUserId(response["userId"]) + if attempt_account_linking: + updated_user = await get_user( + recipe_user_id.get_as_string(), user_context + ) + + if updated_user: + # Check if the verified email is currently associated with the user ID + email_info = await self.get_email_for_recipe_user_id( + updated_user, recipe_user_id, user_context + ) + if ( + isinstance(email_info, GetEmailForUserIdOkResult) + and email_info.email == response["email"] + ): + from ..accountlinking.recipe import AccountLinkingRecipe + + account_linking = AccountLinkingRecipe.get_instance() + await account_linking.try_linking_by_account_info_or_create_primary_user( + tenant_id=tenant_id, + input_user=updated_user, + session=None, + user_context=user_context, + ) + return VerifyEmailUsingTokenOkResult( - User(response["userId"], response["email"]) + EmailVerificationUser(recipe_user_id, response["email"]) ) - return VerifyEmailUsingTokenInvalidTokenError() + else: + return VerifyEmailUsingTokenInvalidTokenError() async def is_email_verified( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> bool: - params = {"userId": user_id, "email": email} + params = {"userId": recipe_user_id.get_as_string(), "email": email} response = await self.querier.send_get_request( NormalisedURLPath("/recipe/user/email/verify"), params, user_context ) return response["isVerified"] async def revoke_email_verification_tokens( - self, user_id: str, email: str, tenant_id: str, user_context: Dict[str, Any] + self, + recipe_user_id: RecipeUserId, + email: str, + tenant_id: str, + user_context: Dict[str, Any], ) -> RevokeEmailVerificationTokensOkResult: - data = {"userId": user_id, "email": email} + data = {"userId": recipe_user_id.get_as_string(), "email": email} await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/user/email/verify/token/remove"), data, @@ -92,9 +146,9 @@ async def revoke_email_verification_tokens( return RevokeEmailVerificationTokensOkResult() async def unverify_email( - self, user_id: str, email: str, user_context: Dict[str, Any] + self, recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] ) -> UnverifyEmailOkResult: - data = {"userId": user_id, "email": email} + data = {"userId": recipe_user_id.get_as_string(), "email": email} await self.querier.send_post_request( NormalisedURLPath("/recipe/user/email/verify/remove"), data, user_context ) diff --git a/supertokens_python/recipe/emailverification/syncio/__init__.py b/supertokens_python/recipe/emailverification/syncio/__init__.py index 9621cbec9..914f86498 100644 --- a/supertokens_python/recipe/emailverification/syncio/__init__.py +++ b/supertokens_python/recipe/emailverification/syncio/__init__.py @@ -16,11 +16,12 @@ from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.recipe.emailverification.types import EmailTemplateVars +from supertokens_python.types import RecipeUserId def create_email_verification_token( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): @@ -29,35 +30,40 @@ def create_email_verification_token( ) return sync( - create_email_verification_token(tenant_id, user_id, email, user_context) + create_email_verification_token(tenant_id, recipe_user_id, email, user_context) ) def verify_email_using_token( tenant_id: str, token: str, + attempt_account_linking: bool = True, user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import ( verify_email_using_token, ) - return sync(verify_email_using_token(tenant_id, token, user_context)) + return sync( + verify_email_using_token( + tenant_id, token, attempt_account_linking, user_context + ) + ) def is_email_verified( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import is_email_verified - return sync(is_email_verified(user_id, email, user_context)) + return sync(is_email_verified(recipe_user_id, email, user_context)) def revoke_email_verification_tokens( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, ): @@ -66,18 +72,18 @@ def revoke_email_verification_tokens( ) return sync( - revoke_email_verification_tokens(tenant_id, user_id, email, user_context) + revoke_email_verification_tokens(tenant_id, recipe_user_id, email, user_context) ) def unverify_email( - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import unverify_email - return sync(unverify_email(user_id, email, user_context)) + return sync(unverify_email(recipe_user_id, email, user_context)) def send_email( @@ -91,7 +97,7 @@ def send_email( def create_email_verification_link( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ): @@ -99,12 +105,15 @@ def create_email_verification_link( create_email_verification_link, ) - return sync(create_email_verification_link(tenant_id, user_id, email, user_context)) + return sync( + create_email_verification_link(tenant_id, recipe_user_id, email, user_context) + ) def send_email_verification_email( tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, email: Optional[str], user_context: Optional[Dict[str, Any]] = None, ): @@ -112,4 +121,8 @@ def send_email_verification_email( send_email_verification_email, ) - return sync(send_email_verification_email(tenant_id, user_id, email, user_context)) + return sync( + send_email_verification_email( + tenant_id, user_id, recipe_user_id, email, user_context + ) + ) diff --git a/supertokens_python/recipe/emailverification/types.py b/supertokens_python/recipe/emailverification/types.py index c32b46ddf..dc95cbb54 100644 --- a/supertokens_python/recipe/emailverification/types.py +++ b/supertokens_python/recipe/emailverification/types.py @@ -13,26 +13,41 @@ # under the License. from __future__ import annotations -from typing import Union +from typing import Any, Dict, Union from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import ( SMTPServiceInterface, EmailDeliveryInterface, ) +from supertokens_python.types import RecipeUserId -class User: - def __init__(self, user_id: str, email: str): - self.user_id = user_id +class EmailVerificationUser: + def __init__(self, recipe_user_id: RecipeUserId, email: str): + self.recipe_user_id = recipe_user_id self.email = email + def to_json(self) -> Dict[str, Any]: + return { + "recipeUserId": self.recipe_user_id.get_as_string(), + "email": self.email, + } + class VerificationEmailTemplateVarsUser: - def __init__(self, user_id: str, email: str): - self.id = user_id + def __init__(self, _id: str, recipe_user_id: RecipeUserId, email: str): + self.id = _id + self.recipe_user_id = recipe_user_id self.email = email + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "email": self.email, + } + class VerificationEmailTemplateVars: def __init__( diff --git a/supertokens_python/recipe/emailverification/utils.py b/supertokens_python/recipe/emailverification/utils.py index 7dcebb515..71d207e8c 100644 --- a/supertokens_python/recipe/emailverification/utils.py +++ b/supertokens_python/recipe/emailverification/utils.py @@ -55,20 +55,20 @@ def __init__( get_email_delivery_config: Callable[ [], EmailDeliveryConfigWithService[VerificationEmailTemplateVars] ], - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction], + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction], override: OverrideConfig, ): self.mode = mode self.override = override self.get_email_delivery_config = get_email_delivery_config - self.get_email_for_user_id = get_email_for_user_id + self.get_email_for_recipe_user_id = get_email_for_recipe_user_id def validate_and_normalise_user_input( app_info: AppInfo, mode: MODE_TYPE, email_delivery: Union[EmailDeliveryConfig[EmailTemplateVars], None] = None, - get_email_for_user_id: Optional[TypeGetEmailForUserIdFunction] = None, + get_email_for_recipe_user_id: Optional[TypeGetEmailForUserIdFunction] = None, override: Union[OverrideConfig, None] = None, ) -> EmailVerificationConfig: if mode not in ["REQUIRED", "OPTIONAL"]: @@ -98,7 +98,7 @@ def get_email_delivery_config() -> ( return EmailVerificationConfig( mode, get_email_delivery_config, - get_email_for_user_id, + get_email_for_recipe_user_id, override, ) diff --git a/supertokens_python/recipe/multifactorauth/__init__.py b/supertokens_python/recipe/multifactorauth/__init__.py new file mode 100644 index 000000000..fa3aaf6f8 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Optional, Union + +from supertokens_python.recipe.multifactorauth.types import OverrideConfig + +from .recipe import MultiFactorAuthRecipe + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + from ...recipe_module import RecipeModule + + +def init( + first_factors: Optional[List[str]] = None, + override: Union[OverrideConfig, None] = None, +) -> Callable[[AppInfo], RecipeModule]: + return MultiFactorAuthRecipe.init( + first_factors, + override, + ) diff --git a/supertokens_python/recipe/multifactorauth/api/__init__.py b/supertokens_python/recipe/multifactorauth/api/__init__.py new file mode 100644 index 000000000..13214e8aa --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/api/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from .resync_session_and_fetch_mfa_info import handle_resync_session_and_fetch_mfa_info_api # type: ignore diff --git a/supertokens_python/recipe/multifactorauth/api/implementation.py b/supertokens_python/recipe/multifactorauth/api/implementation.py new file mode 100644 index 000000000..edb888256 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/api/implementation.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +import importlib + +from typing import Any, Dict, List, Union +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.asyncio import get_user +from supertokens_python.recipe.session.exceptions import ( + InvalidClaimsError, + SuperTokensSessionError, + UnauthorisedError, +) + +from supertokens_python.types import GeneralErrorResponse +from ..interfaces import ( + APIInterface, + APIOptions, + NextFactors, + ResyncSessionAndFetchMFAInfoPUTOkResult, +) + + +class APIImplementation(APIInterface): + async def resync_session_and_fetch_mfa_info_put( + self, + api_options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) + + session_user = await get_user(session.get_user_id(), user_context) + + if session_user is None: + raise UnauthorisedError( + "Session user not found", + ) + + mfa_info = await module.update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + factors_setup_for_user = ( + await api_options.recipe_implementation.get_factors_setup_for_user( + user=session_user, + user_context=user_context, + ) + ) + tenant_info = await get_tenant( + session.get_tenant_id(user_context), user_context + ) + if tenant_info is None: + raise UnauthorisedError( + "Tenant not found", + ) + all_available_secondary_factors = ( + await api_options.recipe_instance.get_all_available_secondary_factor_ids( + tenant_info + ) + ) + + factors_allowed_to_setup: List[str] = [] + + async def get_factors_set_up_for_user(): + return factors_setup_for_user + + async def get_mfa_requirements_for_auth(): + return mfa_info.mfa_requirements_for_auth + + for factor_id in all_available_secondary_factors: + try: + await api_options.recipe_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session=session, + factor_id=factor_id, + factors_set_up_for_user=get_factors_set_up_for_user, + mfa_requirements_for_auth=get_mfa_requirements_for_auth, + user_context=user_context, + ) + factors_allowed_to_setup.append(factor_id) + except SuperTokensSessionError as err: + if not isinstance(err, InvalidClaimsError): + raise err + + next_set_of_unsatisfied_factors = ( + MultiFactorAuthClaim.get_next_set_of_unsatisfied_factors( + mfa_info.completed_factors, mfa_info.mfa_requirements_for_auth + ) + ) + + get_emails_for_factors_result = ( + await api_options.recipe_instance.get_emails_for_factors( + session_user, session.get_recipe_user_id(user_context) + ) + ) + get_phone_numbers_for_factors_result = ( + await api_options.recipe_instance.get_phone_numbers_for_factors( + session_user, session.get_recipe_user_id(user_context) + ) + ) + if ( + get_emails_for_factors_result.status == "UNKNOWN_SESSION_RECIPE_USER_ID" + or get_phone_numbers_for_factors_result.status + == "UNKNOWN_SESSION_RECIPE_USER_ID" + ): + raise UnauthorisedError( + "User no longer associated with the session", + ) + + next_factors = [ + factor_id + for factor_id in next_set_of_unsatisfied_factors.factor_ids + if factor_id in factors_allowed_to_setup + or factor_id in factors_setup_for_user + ] + + if ( + len(next_factors) == 0 + and len(next_set_of_unsatisfied_factors.factor_ids) != 0 + ): + raise Exception( + f"The user is required to complete secondary factors they are not allowed to " + f"({', '.join(next_set_of_unsatisfied_factors.factor_ids)}), likely because of configuration issues." + ) + return ResyncSessionAndFetchMFAInfoPUTOkResult( + factors=NextFactors( + next_=next_factors, + already_setup=factors_setup_for_user, + allowed_to_setup=factors_allowed_to_setup, + ), + emails=get_emails_for_factors_result.factor_id_to_emails_map, + phone_numbers=get_phone_numbers_for_factors_result.factor_id_to_phone_number_map, + ) diff --git a/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py b/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py new file mode 100644 index 000000000..8d7f1e8eb --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/api/resync_session_and_fetch_mfa_info.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict + +if TYPE_CHECKING: + from ..interfaces import ( + APIOptions, + APIInterface, + ) + +from supertokens_python.utils import send_200_response + +from supertokens_python.recipe.session.asyncio import get_session + + +async def handle_resync_session_and_fetch_mfa_info_api( + _tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +): + if api_implementation.disable_resync_session_and_fetch_mfa_info_put is True: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + assert session is not None + + response = await api_implementation.resync_session_and_fetch_mfa_info_put( + api_options, + session, + user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/multifactorauth/asyncio/__init__.py b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py new file mode 100644 index 000000000..8f51ced5b --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/asyncio/__init__.py @@ -0,0 +1,159 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, Dict, Optional, List + +from supertokens_python.recipe.session import SessionContainer + +from ..types import ( + MFARequirementList, +) +from ..utils import update_and_get_mfa_related_info_in_session +from supertokens_python.recipe.accountlinking.asyncio import get_user + + +async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + mfa_info = await update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + factors_set_up_for_user = await get_factors_setup_for_user( + session.get_user_id(), user_context + ) + from ..recipe import MultiFactorAuthRecipe + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + + async def func_factors_set_up_for_user(): + return factors_set_up_for_user + + async def func_mfa_requirements_for_auth(): + return mfa_info.mfa_requirements_for_auth + + await recipe.recipe_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session=session, + factor_id=factor_id, + factors_set_up_for_user=func_factors_set_up_for_user, + mfa_requirements_for_auth=func_mfa_requirements_for_auth, + user_context=user_context, + ) + + +async def get_mfa_requirements_for_auth( + session: SessionContainer, + user_context: Optional[Dict[str, Any]] = None, +) -> MFARequirementList: + if user_context is None: + user_context = {} + + mfa_info = await update_and_get_mfa_related_info_in_session( + input_session=session, + user_context=user_context, + ) + + return mfa_info.mfa_requirements_for_auth + + +async def mark_factor_as_complete_in_session( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + from ..recipe import MultiFactorAuthRecipe + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + await recipe.recipe_implementation.mark_factor_as_complete_in_session( + session=session, + factor_id=factor_id, + user_context=user_context, + ) + + +async def get_factors_setup_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + user = await get_user(user_id, user_context) + if user is None: + raise Exception("Unknown user id") + from ..recipe import MultiFactorAuthRecipe + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + return await recipe.recipe_implementation.get_factors_setup_for_user( + user=user, + user_context=user_context, + ) + + +async def get_required_secondary_factors_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + from ..recipe import MultiFactorAuthRecipe + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + return await recipe.recipe_implementation.get_required_secondary_factors_for_user( + user_id=user_id, + user_context=user_context, + ) + + +async def add_to_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + from ..recipe import MultiFactorAuthRecipe + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + await recipe.recipe_implementation.add_to_required_secondary_factors_for_user( + user_id=user_id, + factor_id=factor_id, + user_context=user_context, + ) + + +async def remove_from_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + from ..recipe import MultiFactorAuthRecipe + + recipe = MultiFactorAuthRecipe.get_instance_or_throw_error() + await recipe.recipe_implementation.remove_from_required_secondary_factors_for_user( + user_id=user_id, + factor_id=factor_id, + user_context=user_context, + ) diff --git a/supertokens_python/recipe/emailverification/ev_claim_validators.py b/supertokens_python/recipe/multifactorauth/constants.py similarity index 93% rename from supertokens_python/recipe/emailverification/ev_claim_validators.py rename to supertokens_python/recipe/multifactorauth/constants.py index dd5f414fc..5f7a38023 100644 --- a/supertokens_python/recipe/emailverification/ev_claim_validators.py +++ b/supertokens_python/recipe/multifactorauth/constants.py @@ -11,4 +11,4 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from __future__ import annotations +RESYNC_SESSION_AND_FETCH_MFA_INFO = "/mfa/info" diff --git a/supertokens_python/recipe/multifactorauth/interfaces.py b/supertokens_python/recipe/multifactorauth/interfaces.py new file mode 100644 index 000000000..f960e9ffe --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/interfaces.py @@ -0,0 +1,159 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Dict, Any, Union, List, Callable, Awaitable, TYPE_CHECKING +from ...types import APIResponse, GeneralErrorResponse + +if TYPE_CHECKING: + from supertokens_python.framework import BaseRequest, BaseResponse + from supertokens_python.recipe.session import SessionContainer + from .types import MFARequirementList, MultiFactorAuthConfig + from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe + from ...supertokens import AppInfo + from supertokens_python.types import User + + +class RecipeInterface(ABC): + @abstractmethod + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + self, + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> None: + pass + + @abstractmethod + async def get_mfa_requirements_for_auth( + self, + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + pass + + @abstractmethod + async def mark_factor_as_complete_in_session( + self, + session: SessionContainer, + factor_id: str, + user_context: Dict[str, Any], + ) -> None: + pass + + @abstractmethod + async def get_factors_setup_for_user( + self, user: User, user_context: Dict[str, Any] + ) -> List[str]: + pass + + @abstractmethod + async def get_required_secondary_factors_for_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> List[str]: + pass + + @abstractmethod + async def add_to_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ) -> None: + pass + + @abstractmethod + async def remove_from_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ) -> None: + pass + + +class APIOptions: + def __init__( + self, + request: BaseRequest, + response: BaseResponse, + recipe_id: str, + config: MultiFactorAuthConfig, + recipe_implementation: RecipeInterface, + app_info: AppInfo, + recipe_instance: MultiFactorAuthRecipe, + ): + self.request: BaseRequest = request + self.response: BaseResponse = response + self.recipe_id: str = recipe_id + self.config = config + self.recipe_implementation: RecipeInterface = recipe_implementation + self.app_info = app_info + self.recipe_instance = recipe_instance + + +class APIInterface: + def __init__(self): + self.disable_resync_session_and_fetch_mfa_info_put = False + + @abstractmethod + async def resync_session_and_fetch_mfa_info_put( + self, + api_options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + pass + + +class NextFactors: + def __init__( + self, next_: List[str], already_setup: List[str], allowed_to_setup: List[str] + ): + self.next_ = next_ + self.already_setup = already_setup + self.allowed_to_setup = allowed_to_setup + + def to_json(self) -> Dict[str, Any]: + return { + "next": self.next_, + "alreadySetup": self.already_setup, + "allowedToSetup": self.allowed_to_setup, + } + + +class ResyncSessionAndFetchMFAInfoPUTOkResult(APIResponse): + def __init__( + self, + factors: NextFactors, + emails: Dict[str, List[str]], + phone_numbers: Dict[str, List[str]], + ): + self.factors = factors + self.emails = emails + self.phone_numbers = phone_numbers + + status: str = "OK" + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "factors": self.factors.to_json(), + "emails": self.emails, + "phoneNumbers": self.phone_numbers, + } diff --git a/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py new file mode 100644 index 000000000..8b45108ba --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/multi_factor_auth_claim.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import importlib + +from typing import Any, Dict, Optional, Set + +from supertokens_python.recipe.session.interfaces import ( + ClaimValidationResult, + JSONObject, + SessionClaim, + SessionClaimValidator, +) +from supertokens_python.types import RecipeUserId + +from .types import ( + FactorIdsAndType, + MFAClaimValue, + MFARequirementList, +) + + +class HasCompletedRequirementListSCV(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: MultiFactorAuthClaimClass, + requirement_list: MFARequirementList, + ): + super().__init__(id_) + self.claim: MultiFactorAuthClaimClass = claim + self.requirement_list = requirement_list + + def should_refetch( + self, payload: Dict[str, Any], user_context: Dict[str, Any] + ) -> bool: + return bool(self.claim.key not in payload or not payload[self.claim.key]) + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + if len(self.requirement_list) == 0: + return ClaimValidationResult(is_valid=True) # no requirements to satisfy + + if (self.claim.key not in payload) or (not payload[self.claim.key]): + raise Exception( + "This should never happen, claim value not present in payload" + ) + + claim_val: MFAClaimValue = MFAClaimValue( + c=payload[self.claim.key]["c"], v=payload[self.claim.key]["v"] + ) + + completed_factors = claim_val.c + next_set_of_unsatisfied_factors = ( + self.claim.get_next_set_of_unsatisfied_factors( + completed_factors, self.requirement_list + ) + ) + + if len(next_set_of_unsatisfied_factors.factor_ids) == 0: + return ClaimValidationResult( + is_valid=True + ) # No item in the requirementList is left unsatisfied, hence is Valid + + factor_ids = next_set_of_unsatisfied_factors.factor_ids + + if next_set_of_unsatisfied_factors.type_ == "string": + return ClaimValidationResult( + is_valid=False, + reason={ + "message": f"Factor validation failed: {factor_ids[0]} not completed", + "factor_id": factor_ids[0], + }, + ) + + elif next_set_of_unsatisfied_factors.type_ == "oneOf": + return ClaimValidationResult( + is_valid=False, + reason={ + "reason": f"None of these factors are complete in the session: {', '.join(factor_ids)}", + "one_of": factor_ids, + }, + ) + else: + return ClaimValidationResult( + is_valid=False, + reason={ + "reason": f"Some of the factors are not complete in the session: {', '.join(factor_ids)}", + "all_of_in_any_order": factor_ids, + }, + ) + + +class HasCompletedMFARequirementsForAuthSCV(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: MultiFactorAuthClaimClass, + ): + super().__init__(id_) + self.claim = claim + + def should_refetch( + self, payload: Dict[str, Any], user_context: Dict[str, Any] + ) -> bool: + assert self.claim is not None + return bool(self.claim.key not in payload or not payload[self.claim.key]) + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + assert self.claim is not None + if self.claim.key not in payload or not payload[self.claim.key]: + raise Exception( + "This should never happen, claim value not present in payload" + ) + claim_val: MFAClaimValue = MFAClaimValue( + c=payload[self.claim.key]["c"], v=payload[self.claim.key]["v"] + ) + + return ClaimValidationResult( + is_valid=claim_val.v, + reason=( + { + "message": "MFA requirement for auth is not satisfied", + } + if not claim_val.v + else None + ), + ) + + +class MultiFactorAuthClaimValidators: + def __init__(self, claim: MultiFactorAuthClaimClass): + self.claim = claim + + def has_completed_requirement_list( + self, requirement_list: MFARequirementList, claim_key: Optional[str] = None + ) -> SessionClaimValidator: + return HasCompletedRequirementListSCV( + id_=claim_key or self.claim.key, + claim=self.claim, + requirement_list=requirement_list, + ) + + def has_completed_mfa_requirements_for_auth( + self, claim_key: Optional[str] = None + ) -> SessionClaimValidator: + + return HasCompletedMFARequirementsForAuthSCV( + id_=claim_key or self.claim.key, + claim=self.claim, + ) + + +class MultiFactorAuthClaimClass(SessionClaim[MFAClaimValue]): + def __init__(self, key: Optional[str] = None): + key = key or "st-mfa" + + async def fetch_value( + _user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> MFAClaimValue: + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) + + mfa_info = await module.update_and_get_mfa_related_info_in_session( + input_session_recipe_user_id=recipe_user_id, + input_tenant_id=tenant_id, + input_access_token_payload=current_payload, + user_context=user_context, + ) + return MFAClaimValue( + c=mfa_info.completed_factors, + v=mfa_info.is_mfa_requirements_for_auth_satisfied, + ) + + super().__init__(key or "st-mfa", fetch_value=fetch_value) + self.validators = MultiFactorAuthClaimValidators(claim=self) + + def get_next_set_of_unsatisfied_factors( + self, completed_factors: Dict[str, int], requirement_list: MFARequirementList + ) -> FactorIdsAndType: + for req in requirement_list: + next_factors: Set[str] = set() + factor_type = "string" + + if isinstance(req, str): + if req not in completed_factors: + factor_type = "string" + next_factors.add(req) + else: + if "oneOf" in req: + satisfied = any( + factor_id in completed_factors for factor_id in req["oneOf"] + ) + if not satisfied: + factor_type = "oneOf" + next_factors.update(req["oneOf"]) + elif "allOfInAnyOrder" in req: + factor_type = "allOfInAnyOrder" + next_factors.update( + factor_id + for factor_id in req["allOfInAnyOrder"] + if factor_id not in completed_factors + ) + + if len(next_factors) > 0: + return FactorIdsAndType( + factor_ids=list(next_factors), type_=factor_type + ) + + return FactorIdsAndType(factor_ids=[], type_="string") + + def add_to_payload_( + self, + payload: JSONObject, + value: MFAClaimValue, + user_context: Optional[Dict[str, Any]] = None, + ) -> JSONObject: + prev_value = payload.get(self.key, {}) + return { + **payload, + self.key: { + "c": {**prev_value.get("c", {}), **value.c}, + "v": value.v, + }, + } + + def remove_from_payload( + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> JSONObject: + del payload[self.key] + return payload + + def remove_from_payload_by_merge_( + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> JSONObject: + payload[self.key] = None + return payload + + def get_value_from_payload( + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None + ) -> Optional[MFAClaimValue]: + value = payload.get(self.key) + if value is None: + return None + return MFAClaimValue(c=value["c"], v=value["v"]) + + +MultiFactorAuthClaim = MultiFactorAuthClaimClass() diff --git a/supertokens_python/recipe/multifactorauth/recipe.py b/supertokens_python/recipe/multifactorauth/recipe.py new file mode 100644 index 000000000..139826761 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/recipe.py @@ -0,0 +1,280 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +import importlib + +from os import environ +from typing import Any, Dict, Optional, List, Union + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception +from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.post_init_callbacks import PostSTInitCallbacks +from supertokens_python.querier import Querier +from supertokens_python.recipe.multifactorauth.api import ( + resync_session_and_fetch_mfa_info, +) +from supertokens_python.recipe.multifactorauth.constants import ( + RESYNC_SESSION_AND_FETCH_MFA_INFO, +) +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.supertokens import AppInfo +from supertokens_python.types import User, RecipeUserId +from .types import ( + OverrideConfig, + GetFactorsSetupForUserFromOtherRecipesFunc, + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetEmailsForFactorFromOtherRecipesFunc, + GetPhoneNumbersForFactorsFromOtherRecipesFunc, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + GetEmailsForFactorOkResult, + GetPhoneNumbersForFactorsOkResult, +) +from .interfaces import APIOptions + + +class MultiFactorAuthRecipe(RecipeModule): + recipe_id = "multifactorauth" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + first_factors: Optional[List[str]] = None, + override: Union[OverrideConfig, None] = None, + ): + super().__init__(recipe_id, app_info) + self.get_factors_setup_for_user_from_other_recipes_funcs: List[ + GetFactorsSetupForUserFromOtherRecipesFunc + ] = [] + self.get_all_available_secondary_factor_ids_from_other_recipes_funcs: List[ + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc + ] = [] + self.get_emails_for_factor_from_other_recipes_funcs: List[ + GetEmailsForFactorFromOtherRecipesFunc + ] = [] + self.get_phone_numbers_for_factor_from_other_recipes_funcs: List[ + GetPhoneNumbersForFactorsFromOtherRecipesFunc + ] = [] + self.is_get_mfa_requirements_for_auth_overridden: bool = False + + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) + + self.config = module.validate_and_normalise_user_input( + first_factors, + override, + ) + from .recipe_implementation import RecipeImplementation + + recipe_implementation = RecipeImplementation( + Querier.get_instance(recipe_id), self + ) + self.recipe_implementation = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) + from .api.implementation import APIImplementation + + api_implementation = APIImplementation() + self.api_implementation = ( + api_implementation + if self.config.override.apis is None + else self.config.override.apis(api_implementation) + ) + + def callback(): + from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + + mt_recipe = MultitenancyRecipe.get_instance() + mt_recipe.static_first_factors = self.config.first_factors + + SessionRecipe.get_instance().add_claim_validator_from_other_recipe( + MultiFactorAuthClaim.validators.has_completed_mfa_requirements_for_auth() + ) + + PostSTInitCallbacks.add_post_init_callback(callback) + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + def get_apis_handled(self) -> List[APIHandled]: + return [ + APIHandled( + method="put", + path_without_api_base_path=NormalisedURLPath( + RESYNC_SESSION_AND_FETCH_MFA_INFO + ), + request_id=RESYNC_SESSION_AND_FETCH_MFA_INFO, + disabled=self.api_implementation.disable_resync_session_and_fetch_mfa_info_put, + ) + ] + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ): + options = APIOptions( + request, + response, + self.get_recipe_id(), + self.config, + self.recipe_implementation, + self.get_app_info(), + self, + ) + return await resync_session_and_fetch_mfa_info.handle_resync_session_and_fetch_mfa_info_api( + tenant_id, self.api_implementation, options, user_context + ) + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + first_factors: Optional[List[str]] = None, + override: Union[OverrideConfig, None] = None, + ): + def func(app_info: AppInfo): + if MultiFactorAuthRecipe.__instance is None: + MultiFactorAuthRecipe.__instance = MultiFactorAuthRecipe( + MultiFactorAuthRecipe.recipe_id, + app_info, + first_factors, + override, + ) + return MultiFactorAuthRecipe.__instance + raise_general_exception( + "MultiFactorAuthRecipe recipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def get_instance_or_throw_error() -> MultiFactorAuthRecipe: + if MultiFactorAuthRecipe.__instance is not None: + return MultiFactorAuthRecipe.__instance + raise_general_exception( + "MultiFactorAuth recipe initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def get_instance() -> Optional[MultiFactorAuthRecipe]: + return MultiFactorAuthRecipe.__instance + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + MultiFactorAuthRecipe.__instance = None + + def add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + self, func: GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc + ): + self.get_all_available_secondary_factor_ids_from_other_recipes_funcs.append( + func + ) + + async def get_all_available_secondary_factor_ids( + self, tenant_config: TenantConfig + ) -> List[str]: + factor_ids: List[str] = [] + for ( + func + ) in self.get_all_available_secondary_factor_ids_from_other_recipes_funcs: + factor_ids_res = await func.func(tenant_config) + for factor_id in factor_ids_res: + if factor_id not in factor_ids: + factor_ids.append(factor_id) + return factor_ids + + def add_func_to_get_factors_setup_for_user_from_other_recipes( + self, func: GetFactorsSetupForUserFromOtherRecipesFunc + ): + self.get_factors_setup_for_user_from_other_recipes_funcs.append(func) + + def add_func_to_get_emails_for_factor_from_other_recipes( + self, func: GetEmailsForFactorFromOtherRecipesFunc + ): + self.get_emails_for_factor_from_other_recipes_funcs.append(func) + + async def get_emails_for_factors( + self, user: User, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ]: + + factorIdToEmailsMap: Dict[str, List[str]] = {} + + for func in self.get_emails_for_factor_from_other_recipes_funcs: + func_result = await func.func(user, session_recipe_user_id) + if isinstance( + func_result, GetEmailsForFactorUnknownSessionRecipeUserIdResult + ): + return GetEmailsForFactorUnknownSessionRecipeUserIdResult() + factorIdToEmailsMap.update(func_result.factor_id_to_emails_map) + + return GetEmailsForFactorOkResult(factor_id_to_emails_map=factorIdToEmailsMap) + + def add_func_to_get_phone_numbers_for_factors_from_other_recipes( + self, func: GetPhoneNumbersForFactorsFromOtherRecipesFunc + ): + self.get_phone_numbers_for_factor_from_other_recipes_funcs.append(func) + + async def get_phone_numbers_for_factors( + self, user: User, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + ]: + factorIdToPhoneNumberMap: Dict[str, List[str]] = {} + + for func in self.get_phone_numbers_for_factor_from_other_recipes_funcs: + func_result = await func.func(user, session_recipe_user_id) + if isinstance( + func_result, GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult + ): + return GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult() + factorIdToPhoneNumberMap.update(func_result.factor_id_to_phone_number_map) + + return GetPhoneNumbersForFactorsOkResult( + factor_id_to_phone_number_map=factorIdToPhoneNumberMap + ) diff --git a/supertokens_python/recipe/multifactorauth/recipe_implementation.py b/supertokens_python/recipe/multifactorauth/recipe_implementation.py new file mode 100644 index 000000000..476582eab --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/recipe_implementation.py @@ -0,0 +1,229 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +import importlib + +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Set, Callable, List + +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, + MultiFactorAuthClaimClass, +) +from supertokens_python.recipe.session.interfaces import ( + ClaimValidationResult, + JSONObject, + SessionClaimValidator, +) +from supertokens_python.recipe.usermetadata.asyncio import ( + get_user_metadata, + update_user_metadata, +) +from supertokens_python.recipe.multifactorauth.types import ( + MFAClaimValue, + MFARequirementList, +) +from supertokens_python.recipe.session import SessionContainer + +from supertokens_python.types import User +from .interfaces import RecipeInterface + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + from .recipe import MultiFactorAuthRecipe + + +class Validator(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: MultiFactorAuthClaimClass, + mfa_requirement_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + factor_id: str, + ): + super().__init__(id_) + self.claim: MultiFactorAuthClaimClass = claim + self.factors_set_up_for_user = factors_set_up_for_user + self.factor_id = factor_id + self.mfa_requirement_for_auth = mfa_requirement_for_auth + + def should_refetch( + self, payload: Dict[str, Any], user_context: Dict[str, Any] + ) -> bool: + return self.claim.get_value_from_payload(payload) is None + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + claim_val: MFAClaimValue | None = self.claim.get_value_from_payload(payload) + + if claim_val is None: + raise Exception( + "This should never happen, claim value not present in payload" + ) + + if claim_val.v: + # Session already satisfied auth requirements + return ClaimValidationResult(is_valid=True) + + set_of_unsatisfied_factors = self.claim.get_next_set_of_unsatisfied_factors( + claim_val.c, await self.mfa_requirement_for_auth() + ) + + factors_set_up_for_user = await self.factors_set_up_for_user() + + if any( + factor_id in factors_set_up_for_user + for factor_id in set_of_unsatisfied_factors.factor_ids + ): + return ClaimValidationResult( + is_valid=False, + reason={ + "message": "Completed factors in the session does not satisfy the MFA requirements for auth", + }, + ) + + if ( + set_of_unsatisfied_factors.factor_ids + and self.factor_id not in set_of_unsatisfied_factors.factor_ids + ): + return ClaimValidationResult( + is_valid=False, + reason={ + "message": "Not allowed to setup factor that is not in the next set of unsatisfied factors", + }, + ) + + return ClaimValidationResult(is_valid=True) + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + recipe_instance: MultiFactorAuthRecipe, + ): + super().__init__() + self.querier = querier + self.recipe_instance = recipe_instance + + async def get_factors_setup_for_user( + self, user: User, user_context: Dict[str, Any] + ) -> List[str]: + factor_ids: List[str] = [] + for ( + func + ) in self.recipe_instance.get_factors_setup_for_user_from_other_recipes_funcs: + result = await func.func(user, user_context) + for factor_id in result: + if factor_id not in factor_ids: + factor_ids.append(factor_id) + return factor_ids + + async def get_mfa_requirements_for_auth( + self, + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + all_factors: Set[str] = set() + for factor in await required_secondary_factors_for_user(): + all_factors.add(factor) + for factor in await required_secondary_factors_for_tenant(): + all_factors.add(factor) + return [{"oneOf": list(all_factors)}] + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + self, + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + await session.assert_claims( + [ + Validator( + id_=MultiFactorAuthClaim.key, + claim=MultiFactorAuthClaim, + mfa_requirement_for_auth=mfa_requirements_for_auth, + factors_set_up_for_user=factors_set_up_for_user, + factor_id=factor_id, + ) + ], + user_context, + ) + + async def mark_factor_as_complete_in_session( + self, session: SessionContainer, factor_id: str, user_context: Dict[str, Any] + ): + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) + + await module.update_and_get_mfa_related_info_in_session( + input_session=session, + input_updated_factor_id=factor_id, + user_context=user_context, + ) + + async def get_required_secondary_factors_for_user( + self, user_id: str, user_context: Dict[str, Any] + ) -> List[str]: + metadata = await get_user_metadata(user_id, user_context) + result: List[str] = metadata.metadata.get("_supertokens", {}).get( + "requiredSecondaryFactors", [] + ) + return result + + async def add_to_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ): + metadata = await get_user_metadata(user_id, user_context) + factor_ids: List[str] = metadata.metadata.get("_supertokens", {}).get( + "requiredSecondaryFactors", [] + ) + if factor_id not in factor_ids: + factor_ids.append(factor_id) + metadata_update = { + **metadata.metadata, + "_supertokens": { + **metadata.metadata.get("_supertokens", {}), + "requiredSecondaryFactors": factor_ids, + }, + } + await update_user_metadata(user_id, metadata_update, user_context) + + async def remove_from_required_secondary_factors_for_user( + self, user_id: str, factor_id: str, user_context: Dict[str, Any] + ): + metadata = await get_user_metadata(user_id, user_context) + factor_ids: List[str] = metadata.metadata.get("_supertokens", {}).get( + "requiredSecondaryFactors", [] + ) + if factor_id in factor_ids: + factor_ids = [id for id in factor_ids if id != factor_id] + metadata_update = { + **metadata.metadata, + "_supertokens": { + **metadata.metadata.get("_supertokens", {}), + "requiredSecondaryFactors": factor_ids, + }, + } + await update_user_metadata(user_id, metadata_update, user_context) diff --git a/supertokens_python/recipe/multifactorauth/syncio/__init__.py b/supertokens_python/recipe/multifactorauth/syncio/__init__.py new file mode 100644 index 000000000..8268fd19a --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/syncio/__init__.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, Dict, Optional, List + +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.async_to_sync_wrapper import sync + + +def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error as async_func, + ) + + return sync(async_func(session, factor_id, user_context)) + + +def get_mfa_requirements_for_auth( + session: SessionContainer, + user_context: Optional[Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + get_mfa_requirements_for_auth as async_func, + ) + + return sync(async_func(session, user_context)) + + +def mark_factor_as_complete_in_session( + session: SessionContainer, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + mark_factor_as_complete_in_session as async_func, + ) + + return sync(async_func(session, factor_id, user_context)) + + +def get_factors_setup_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + get_factors_setup_for_user as async_func, + ) + + return sync(async_func(user_id, user_context)) + + +def get_required_secondary_factors_for_user( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> List[str]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + get_required_secondary_factors_for_user as async_func, + ) + + return sync(async_func(user_id, user_context)) + + +def add_to_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user as async_func, + ) + + return sync(async_func(user_id, factor_id, user_context)) + + +def remove_from_required_secondary_factors_for_user( + user_id: str, + factor_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.multifactorauth.asyncio import ( + remove_from_required_secondary_factors_for_user as async_func, + ) + + return sync(async_func(user_id, factor_id, user_context)) diff --git a/supertokens_python/recipe/multifactorauth/types.py b/supertokens_python/recipe/multifactorauth/types.py new file mode 100644 index 000000000..1ffc2ec9f --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/types.py @@ -0,0 +1,146 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +from typing import Awaitable, Dict, Any, Union, List, Optional, Callable, TYPE_CHECKING + +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from typing_extensions import Literal +from supertokens_python.types import User, RecipeUserId + +if TYPE_CHECKING: + from .interfaces import RecipeInterface, APIInterface + + +MFARequirementList = List[ + Union[str, Dict[Union[Literal["oneOf"], Literal["allOfInAnyOrder"]], List[str]]] +] + + +class MFAClaimValue: + c: Dict[str, int] + v: bool + + def __init__(self, c: Dict[str, Any], v: bool): + self.c = c + self.v = v + + +class OverrideConfig: + def __init__( + self, + functions: Union[Callable[[RecipeInterface], RecipeInterface], None] = None, + apis: Union[Callable[[APIInterface], APIInterface], None] = None, + ): + self.functions = functions + self.apis = apis + + +class MultiFactorAuthConfig: + def __init__( + self, + first_factors: Optional[List[str]], + override: OverrideConfig, + ): + self.first_factors = first_factors + self.override = override + + +class FactorIds: + EMAILPASSWORD: Literal["emailpassword"] = "emailpassword" + OTP_EMAIL: Literal["otp-email"] = "otp-email" + OTP_PHONE: Literal["otp-phone"] = "otp-phone" + LINK_EMAIL: Literal["link-email"] = "link-email" + LINK_PHONE: Literal["link-phone"] = "link-phone" + THIRDPARTY: Literal["thirdparty"] = "thirdparty" + TOTP: Literal["totp"] = "totp" + + +class FactorIdsAndType: + def __init__( + self, + factor_ids: List[str], + type_: Union[Literal["string"], Literal["oneOf"], Literal["allOfInAnyOrder"]], + ): + self.factor_ids = factor_ids + self.type_ = type_ + + +class GetFactorsSetupForUserFromOtherRecipesFunc: + def __init__( + self, + func: Callable[[User, Dict[str, Any]], Awaitable[List[str]]], + ): + self.func = func + + +class GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc: + def __init__( + self, + func: Callable[[TenantConfig], Awaitable[List[str]]], + ): + self.func = func + + +class GetEmailsForFactorOkResult: + status: Literal["OK"] = "OK" + + def __init__(self, factor_id_to_emails_map: Dict[str, List[str]]): + self.factor_id_to_emails_map = factor_id_to_emails_map + + +class GetEmailsForFactorUnknownSessionRecipeUserIdResult: + status: Literal["UNKNOWN_SESSION_RECIPE_USER_ID"] = "UNKNOWN_SESSION_RECIPE_USER_ID" + + +class GetEmailsForFactorFromOtherRecipesFunc: + def __init__( + self, + func: Callable[ + [User, RecipeUserId], + Awaitable[ + Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ] + ], + ], + ): + self.func = func + + +class GetPhoneNumbersForFactorsOkResult: + status: Literal["OK"] = "OK" + + def __init__(self, factor_id_to_phone_number_map: Dict[str, List[str]]): + self.factor_id_to_phone_number_map = factor_id_to_phone_number_map + + +class GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult: + status: Literal["UNKNOWN_SESSION_RECIPE_USER_ID"] = "UNKNOWN_SESSION_RECIPE_USER_ID" + + +class GetPhoneNumbersForFactorsFromOtherRecipesFunc: + def __init__( + self, + func: Callable[ + [User, RecipeUserId], + Awaitable[ + Union[ + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + ] + ], + ], + ): + self.func = func diff --git a/supertokens_python/recipe/multifactorauth/utils.py b/supertokens_python/recipe/multifactorauth/utils.py new file mode 100644 index 000000000..12f9ea0f6 --- /dev/null +++ b/supertokens_python/recipe/multifactorauth/utils.py @@ -0,0 +1,314 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations +from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) +from supertokens_python.recipe.multitenancy.asyncio import get_tenant +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.asyncio import get_session_information +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.multifactorauth.types import ( + MFAClaimValue, + MFARequirementList, + FactorIds, +) +from supertokens_python.types import RecipeUserId +import math +import time +from typing_extensions import Literal +from supertokens_python.utils import log_debug_message + +if TYPE_CHECKING: + from .types import OverrideConfig, MultiFactorAuthConfig + + +# IMPORTANT: If this function signature is modified, please update all tha places where this function is called. +# There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. +def validate_and_normalise_user_input( + first_factors: Optional[List[str]], + override: Union[OverrideConfig, None] = None, +) -> MultiFactorAuthConfig: + if first_factors is not None and len(first_factors) == 0: + raise ValueError("'first_factors' can be either None or a non-empty list") + + from .types import OverrideConfig as OC, MultiFactorAuthConfig as MFAC + + if override is None: + override = OC() + + return MFAC( + first_factors=first_factors, + override=override, + ) + + +class UpdateAndGetMFARelatedInfoInSessionResult: + def __init__( + self, + completed_factors: Dict[str, int], + mfa_requirements_for_auth: MFARequirementList, + is_mfa_requirements_for_auth_satisfied: bool, + ): + self.completed_factors = completed_factors + self.mfa_requirements_for_auth = mfa_requirements_for_auth + self.is_mfa_requirements_for_auth_satisfied = ( + is_mfa_requirements_for_auth_satisfied + ) + + +# IMPORTANT: If this function signature is modified, please update all tha places where this function is called. +# There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. +async def update_and_get_mfa_related_info_in_session( + user_context: Dict[str, Any], + input_session_recipe_user_id: Optional[RecipeUserId] = None, + input_tenant_id: Optional[str] = None, + input_access_token_payload: Optional[Dict[str, Any]] = None, + input_session: Optional[SessionContainer] = None, + input_updated_factor_id: Optional[str] = None, +) -> UpdateAndGetMFARelatedInfoInSessionResult: + from supertokens_python.recipe.multifactorauth.recipe import ( + MultiFactorAuthRecipe as Recipe, + ) + + session_recipe_user_id: RecipeUserId + tenant_id: str + access_token_payload: Dict[str, Any] + session_handle: str + + if input_session is not None: + session_recipe_user_id = input_session.get_recipe_user_id(user_context) + tenant_id = input_session.get_tenant_id(user_context) + access_token_payload = input_session.get_access_token_payload(user_context) + session_handle = input_session.get_handle(user_context) + else: + assert input_session_recipe_user_id is not None + assert input_tenant_id is not None + assert input_access_token_payload is not None + session_recipe_user_id = input_session_recipe_user_id + tenant_id = input_tenant_id + access_token_payload = input_access_token_payload + session_handle = access_token_payload["sessionHandle"] + + updated_claim_val = False + mfa_claim_value = MultiFactorAuthClaim.get_value_from_payload(access_token_payload) + + if input_updated_factor_id is not None: + if mfa_claim_value is None: + updated_claim_val = True + mfa_claim_value = MFAClaimValue( + c={input_updated_factor_id: math.floor(time.time())}, + v=True, # updated later in the function + ) + else: + updated_claim_val = True + mfa_claim_value.c[input_updated_factor_id] = math.floor(time.time()) + + if mfa_claim_value is None: + session_user = ( + await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + session_recipe_user_id.get_as_string(), user_context + ) + ) + if session_user is None: + raise UnauthorisedError("Session user not found") + + session_info = await get_session_information(session_handle, user_context) + if session_info is None: + raise UnauthorisedError("Session not found") + + first_factor_time = session_info.time_created + computed_first_factor_id_for_session = None + + for login_method in session_user.login_methods: + if ( + login_method.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ): + if login_method.recipe_id == "emailpassword": + valid_res = await is_valid_first_factor( + tenant_id, FactorIds.EMAILPASSWORD, user_context + ) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise UnauthorisedError("Tenant not found") + elif valid_res == "OK": + computed_first_factor_id_for_session = FactorIds.EMAILPASSWORD + break + elif login_method.recipe_id == "thirdparty": + valid_res = await is_valid_first_factor( + tenant_id, FactorIds.THIRDPARTY, user_context + ) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise UnauthorisedError("Tenant not found") + elif valid_res == "OK": + computed_first_factor_id_for_session = FactorIds.THIRDPARTY + break + else: + factors_to_check: List[str] = [] + if login_method.email is not None: + factors_to_check.extend( + [FactorIds.LINK_EMAIL, FactorIds.OTP_EMAIL] + ) + if login_method.phone_number is not None: + factors_to_check.extend( + [FactorIds.LINK_PHONE, FactorIds.OTP_PHONE] + ) + + for factor_id in factors_to_check: + valid_res = await is_valid_first_factor( + tenant_id, factor_id, user_context + ) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise UnauthorisedError("Tenant not found") + elif valid_res == "OK": + computed_first_factor_id_for_session = factor_id + break + + if computed_first_factor_id_for_session is not None: + break + + if computed_first_factor_id_for_session is None: + raise UnauthorisedError("Incorrect login method used") + + updated_claim_val = True + mfa_claim_value = MFAClaimValue( + c={computed_first_factor_id_for_session: first_factor_time}, + v=True, # updated later in this function + ) + + completed_factors = mfa_claim_value.c + + async def user_getter(): + resp = await AccountLinkingRecipe.get_instance().recipe_implementation.get_user( + session_recipe_user_id.get_as_string(), user_context + ) + if resp is None: + raise UnauthorisedError("Session user not found") + return resp + + async def get_required_secondary_factors_for_tenant( + tenant_id: str, user_context: Dict[str, Any] + ) -> List[str]: + tenant_info = await get_tenant(tenant_id, user_context) + if tenant_info is None: + raise UnauthorisedError("Tenant not found") + return ( + tenant_info.required_secondary_factors + if tenant_info.required_secondary_factors is not None + else [] + ) + + async def get_factors_setup_for_user() -> List[str]: + return await Recipe.get_instance_or_throw_error().recipe_implementation.get_factors_setup_for_user( + user=(await user_getter()), user_context=user_context + ) + + async def get_required_secondary_factors_for_user() -> List[str]: + return await Recipe.get_instance_or_throw_error().recipe_implementation.get_required_secondary_factors_for_user( + user_id=(await user_getter()).id, user_context=user_context + ) + + async def get_required_secondary_factors_for_tenant_helper() -> List[str]: + return await get_required_secondary_factors_for_tenant( + tenant_id=tenant_id, user_context=user_context + ) + + mfa_requirements_for_auth = await Recipe.get_instance_or_throw_error().recipe_implementation.get_mfa_requirements_for_auth( + tenant_id=tenant_id, + access_token_payload=access_token_payload, + user=user_getter, + factors_set_up_for_user=get_factors_setup_for_user, + required_secondary_factors_for_user=get_required_secondary_factors_for_user, + required_secondary_factors_for_tenant=get_required_secondary_factors_for_tenant_helper, + completed_factors=completed_factors, + user_context=user_context, + ) + + are_auth_reqs_complete = ( + len( + MultiFactorAuthClaim.get_next_set_of_unsatisfied_factors( + completed_factors, mfa_requirements_for_auth + ).factor_ids + ) + == 0 + ) + + if mfa_claim_value.v != are_auth_reqs_complete: + updated_claim_val = True + mfa_claim_value.v = are_auth_reqs_complete + + if input_session is not None and updated_claim_val: + await input_session.set_claim_value( + MultiFactorAuthClaim, mfa_claim_value, user_context + ) + + return UpdateAndGetMFARelatedInfoInSessionResult( + completed_factors=completed_factors, + mfa_requirements_for_auth=mfa_requirements_for_auth, + is_mfa_requirements_for_auth_satisfied=mfa_claim_value.v, + ) + + +# IMPORTANT: If this function signature is modified, please update all tha places where this function is called. +# There will be no type errors cause we use importLib to dynamically import if to prevent cyclic import issues. +async def is_valid_first_factor( + tenant_id: str, factor_id: str, user_context: Dict[str, Any] +) -> Literal["OK", "INVALID_FIRST_FACTOR_ERROR", "TENANT_NOT_FOUND_ERROR"]: + + mt_recipe = MultitenancyRecipe.get_instance() + tenant_info = await get_tenant(tenant_id=tenant_id, user_context=user_context) + if tenant_info is None: + return "TENANT_NOT_FOUND_ERROR" + + tenant_config = tenant_info + + first_factors_from_mfa = mt_recipe.static_first_factors + + log_debug_message( + f"is_valid_first_factor got {', '.join(tenant_config.first_factors) if tenant_config.first_factors else None} from tenant config" + ) + log_debug_message(f"is_valid_first_factor got {first_factors_from_mfa} from MFA") + + configured_first_factors: Union[List[str], None] = ( + tenant_config.first_factors or first_factors_from_mfa + ) + + if configured_first_factors is None: + configured_first_factors = mt_recipe.all_available_first_factors + + if is_factor_configured_for_tenant( + all_available_first_factors=mt_recipe.all_available_first_factors, + first_factors=configured_first_factors, + factor_id=factor_id, + ): + return "OK" + + return "INVALID_FIRST_FACTOR_ERROR" + + +def is_factor_configured_for_tenant( + all_available_first_factors: List[str], + first_factors: List[str], + factor_id: str, +) -> bool: + configured_first_factors = [ + f + for f in first_factors + if f in all_available_first_factors or f not in FactorIds.__dict__.values() + ] + + return factor_id in configured_first_factors diff --git a/supertokens_python/recipe/multitenancy/api/implementation.py b/supertokens_python/recipe/multitenancy/api/implementation.py index 77b3ecaf4..565773e76 100644 --- a/supertokens_python/recipe/multitenancy/api/implementation.py +++ b/supertokens_python/recipe/multitenancy/api/implementation.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import importlib from typing import Any, Dict, Optional, Union, List from ..constants import DEFAULT_TENANT_ID @@ -35,6 +36,9 @@ async def login_methods_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[LoginMethodsGetOkResult, GeneralErrorResponse]: + module = importlib.import_module( + "supertokens_python.recipe.multifactorauth.utils" + ) from supertokens_python.recipe.thirdparty.providers.config_utils import ( merge_providers_from_core_and_static, find_and_create_provider_instance, @@ -52,7 +56,7 @@ async def login_methods_get( raise Exception("Tenant not found") provider_inputs_from_static = api_options.static_third_party_providers - provider_configs_from_core = tenant_config.third_party.providers + provider_configs_from_core = tenant_config.third_party_providers merged_providers = merge_providers_from_core_and_static( provider_configs_from_core, @@ -80,12 +84,37 @@ async def login_methods_get( ThirdPartyProvider(provider_instance.id, provider_instance.config.name) ) + first_factors: List[str] = [] + if tenant_config.first_factors is not None: + first_factors = tenant_config.first_factors + elif api_options.static_first_factors is not None: + first_factors = api_options.static_first_factors + else: + first_factors = list(set(api_options.all_available_first_factors)) + + valid_first_factors: List[str] = [] + for factor_id in first_factors: + valid_res = await module.is_valid_first_factor( + tenant_id, factor_id, user_context + ) + if valid_res == "OK": + valid_first_factors.append(factor_id) + if valid_res == "TENANT_NOT_FOUND_ERROR": + raise Exception("Tenant not found") + return LoginMethodsGetOkResult( email_password=LoginMethodEmailPassword( - tenant_config.emailpassword.enabled + enabled="emailpassword" in valid_first_factors + ), + passwordless=LoginMethodPasswordless( + enabled=any( + factor in valid_first_factors + for factor in ["otp-email", "otp-phone", "link-email", "link-phone"] + ) ), - passwordless=LoginMethodPasswordless(tenant_config.passwordless.enabled), third_party=LoginMethodThirdParty( - tenant_config.third_party.enabled, final_provider_list + enabled="thirdparty" in valid_first_factors, + providers=final_provider_list, ), + first_factors=valid_first_factors, ) diff --git a/supertokens_python/recipe/multitenancy/asyncio/__init__.py b/supertokens_python/recipe/multitenancy/asyncio/__init__.py index 1998cae3f..c9b1c6ce7 100644 --- a/supertokens_python/recipe/multitenancy/asyncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/asyncio/__init__.py @@ -14,11 +14,13 @@ from typing import Any, Dict, Union, Optional, TYPE_CHECKING +from supertokens_python.types import RecipeUserId + from ..interfaces import ( + AssociateUserToTenantNotAllowedError, TenantConfig, CreateOrUpdateTenantOkResult, DeleteTenantOkResult, - GetTenantOkResult, ListAllTenantsOkResult, CreateOrUpdateThirdPartyConfigOkResult, DeleteThirdPartyConfigOkResult, @@ -28,8 +30,8 @@ AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, DisassociateUserFromTenantOkResult, + TenantConfigCreateOrUpdate, ) -from ..recipe import MultitenancyRecipe if TYPE_CHECKING: from ..interfaces import ProviderConfig @@ -37,11 +39,13 @@ async def create_or_update_tenant( tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Optional[Dict[str, Any]] = None, ) -> CreateOrUpdateTenantOkResult: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.create_or_update_tenant( @@ -54,6 +58,8 @@ async def delete_tenant( ) -> DeleteTenantOkResult: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.delete_tenant(tenant_id, user_context) @@ -61,9 +67,11 @@ async def delete_tenant( async def get_tenant( tenant_id: str, user_context: Optional[Dict[str, Any]] = None -) -> Optional[GetTenantOkResult]: +) -> Optional[TenantConfig]: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.get_tenant(tenant_id, user_context) @@ -75,6 +83,8 @@ async def list_all_tenants( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.list_all_tenants(user_context) @@ -89,6 +99,8 @@ async def create_or_update_third_party_config( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.create_or_update_third_party_config( @@ -104,6 +116,8 @@ async def delete_third_party_config( if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.delete_third_party_config( @@ -113,7 +127,7 @@ async def delete_third_party_config( async def associate_user_to_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ) -> Union[ AssociateUserToTenantOkResult, @@ -121,27 +135,32 @@ async def associate_user_to_tenant( AssociateUserToTenantEmailAlreadyExistsError, AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantNotAllowedError, ]: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() return await recipe.recipe_implementation.associate_user_to_tenant( - tenant_id, user_id, user_context + tenant_id, recipe_user_id, user_context ) -async def dissociate_user_from_tenant( +async def disassociate_user_from_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ) -> DisassociateUserFromTenantOkResult: if user_context is None: user_context = {} + from ..recipe import MultitenancyRecipe + recipe = MultitenancyRecipe.get_instance() - return await recipe.recipe_implementation.dissociate_user_from_tenant( - tenant_id, user_id, user_context + return await recipe.recipe_implementation.disassociate_user_from_tenant( + tenant_id, recipe_user_id, user_context ) diff --git a/supertokens_python/recipe/multitenancy/interfaces.py b/supertokens_python/recipe/multitenancy/interfaces.py index 6bf03c9c8..ba8a27537 100644 --- a/supertokens_python/recipe/multitenancy/interfaces.py +++ b/supertokens_python/recipe/multitenancy/interfaces.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Union, Callable, Awaitable, Optional, List -from supertokens_python.types import APIResponse, GeneralErrorResponse +from supertokens_python.types import APIResponse, GeneralErrorResponse, RecipeUserId if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse @@ -28,29 +28,82 @@ class TenantConfig: + # pylint: disable=dangerous-default-value def __init__( self, - email_password_enabled: Union[bool, None] = None, - passwordless_enabled: Union[bool, None] = None, - third_party_enabled: Union[bool, None] = None, - core_config: Union[Dict[str, Any], None] = None, + tenant_id: str = "", + third_party_providers: List[ProviderConfig] = [], + core_config: Dict[str, Any] = {}, + first_factors: Optional[List[str]] = None, + required_secondary_factors: Optional[List[str]] = None, ): - self.email_password_enabled = email_password_enabled - self.passwordless_enabled = passwordless_enabled - self.third_party_enabled = third_party_enabled + self.tenant_id = tenant_id self.core_config = core_config - - def to_json(self) -> Dict[str, Any]: - res: Dict[str, Any] = {} - if self.email_password_enabled is not None: - res["emailPasswordEnabled"] = self.email_password_enabled - if self.passwordless_enabled is not None: - res["passwordlessEnabled"] = self.passwordless_enabled - if self.third_party_enabled is not None: - res["thirdPartyEnabled"] = self.third_party_enabled - if self.core_config is not None: - res["coreConfig"] = self.core_config - return res + self.first_factors = first_factors + self.required_secondary_factors = required_secondary_factors + self.third_party_providers = third_party_providers + + @staticmethod + def from_json(json: Dict[str, Any]) -> TenantConfig: + return TenantConfig( + tenant_id=json.get("tenantId", ""), + third_party_providers=[ + ProviderConfig.from_json(provider) + for provider in json.get("thirdPartyProviders", []) + ], + core_config=json.get("coreConfig", {}), + first_factors=json.get("firstFactors", []), + required_secondary_factors=json.get("requiredSecondaryFactors", []), + ) + + +class TenantConfigCreateOrUpdate: + # pylint: disable=dangerous-default-value + def __init__( + self, + core_config: Dict[str, Any] = {}, + first_factors: Optional[List[str]] = [ + "NO_CHANGE" + ], # A default value here means that if the user does not set this, it will not make any change in the core. This is different from None, + # which means that the user wants to unset it in the core. + required_secondary_factors: Optional[List[str]] = [ + "NO_CHANGE" + ], # A default value here means that if the user does not set this, it will not make any change in the core. This is different from None, + # which means that the user wants to unset it in the core. + ): + self.core_config = core_config + self._first_factors = first_factors + self._required_secondary_factors = required_secondary_factors + + def is_first_factors_unchanged(self) -> bool: + return self._first_factors == ["NO_CHANGE"] + + def is_required_secondary_factors_unchanged(self) -> bool: + return self._required_secondary_factors == ["NO_CHANGE"] + + def get_first_factors_for_update(self) -> Optional[List[str]]: + if self._first_factors == ["NO_CHANGE"]: + raise Exception( + "First check if the value of first_factors is not NO_CHANGE" + ) + return self._first_factors + + def get_required_secondary_factors_for_update(self) -> Optional[List[str]]: + if self._required_secondary_factors == ["NO_CHANGE"]: + raise Exception( + "First check if the value of required_secondary_factors is not NO_CHANGE" + ) + return self._required_secondary_factors + + @staticmethod + def from_json(json: Dict[str, Any]) -> TenantConfigCreateOrUpdate: + return TenantConfigCreateOrUpdate( + core_config=json.get("coreConfig", {}), + first_factors=json.get("firstFactors", ["NO_CHANGE"]), + required_secondary_factors=json.get( + "requiredSecondaryFactors", ["NO_CHANGE"] + ), + ) class CreateOrUpdateTenantOkResult: @@ -67,80 +120,10 @@ def __init__(self, did_exist: bool): self.did_exist = did_exist -class EmailPasswordConfig: - def __init__(self, enabled: bool): - self.enabled = enabled - - def to_json(self): - return {"enabled": self.enabled} - - -class PasswordlessConfig: - def __init__(self, enabled: bool): - self.enabled = enabled - - def to_json(self): - return {"enabled": self.enabled} - - -class ThirdPartyConfig: - def __init__(self, enabled: bool, providers: List[ProviderConfig]): - self.enabled = enabled - self.providers = providers - - def to_json(self): - return { - "enabled": self.enabled, - "providers": [provider.to_json() for provider in self.providers], - } - - -class TenantConfigResponse: - def __init__( - self, - emailpassword: EmailPasswordConfig, - passwordless: PasswordlessConfig, - third_party: ThirdPartyConfig, - core_config: Dict[str, Any], - ): - self.emailpassword = emailpassword - self.passwordless = passwordless - self.third_party = third_party - self.core_config = core_config - - -class GetTenantOkResult(TenantConfigResponse): - status = "OK" - - -class ListAllTenantsItem(TenantConfigResponse): - def __init__( - self, - tenant_id: str, - emailpassword: EmailPasswordConfig, - passwordless: PasswordlessConfig, - third_party: ThirdPartyConfig, - core_config: Dict[str, Any], - ): - super().__init__(emailpassword, passwordless, third_party, core_config) - self.tenant_id = tenant_id - - def to_json(self): - res = { - "tenantId": self.tenant_id, - "emailPassword": self.emailpassword.to_json(), - "passwordless": self.passwordless.to_json(), - "thirdParty": self.third_party.to_json(), - "coreConfig": self.core_config, - } - - return res - - class ListAllTenantsOkResult: status = "OK" - def __init__(self, tenants: List[ListAllTenantsItem]): + def __init__(self, tenants: List[TenantConfig]): self.tenants = tenants @@ -181,6 +164,14 @@ class AssociateUserToTenantThirdPartyUserAlreadyExistsError: status = "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR" +class AssociateUserToTenantNotAllowedError: + status = "ASSOCIATION_NOT_ALLOWED_ERROR" + + def __init__(self, reason: str): + self.status = "ASSOCIATION_NOT_ALLOWED_ERROR" + self.reason = reason + + class DisassociateUserFromTenantOkResult: status = "OK" @@ -202,7 +193,7 @@ async def get_tenant_id( async def create_or_update_tenant( self, tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Dict[str, Any], ) -> CreateOrUpdateTenantOkResult: pass @@ -216,7 +207,7 @@ async def delete_tenant( @abstractmethod async def get_tenant( self, tenant_id: str, user_context: Dict[str, Any] - ) -> Optional[GetTenantOkResult]: + ) -> Optional[TenantConfig]: pass @abstractmethod @@ -250,7 +241,7 @@ async def delete_third_party_config( async def associate_user_to_tenant( self, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> Union[ AssociateUserToTenantOkResult, @@ -258,14 +249,15 @@ async def associate_user_to_tenant( AssociateUserToTenantEmailAlreadyExistsError, AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantNotAllowedError, ]: pass @abstractmethod - async def dissociate_user_from_tenant( + async def disassociate_user_from_tenant( self, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Dict[str, Any], ) -> DisassociateUserFromTenantOkResult: pass @@ -280,6 +272,8 @@ def __init__( config: MultitenancyConfig, recipe_implementation: RecipeInterface, static_third_party_providers: List[ProviderInput], + all_available_first_factors: List[str], + static_first_factors: Optional[List[str]], ): self.request = request self.response = response @@ -287,6 +281,8 @@ def __init__( self.config = config self.recipe_implementation = recipe_implementation self.static_third_party_providers = static_third_party_providers + self.static_first_factors = static_first_factors + self.all_available_first_factors = all_available_first_factors class ThirdPartyProvider: @@ -341,11 +337,13 @@ def __init__( email_password: LoginMethodEmailPassword, passwordless: LoginMethodPasswordless, third_party: LoginMethodThirdParty, + first_factors: List[str], ): self.status = "OK" self.email_password = email_password self.passwordless = passwordless self.third_party = third_party + self.first_factors = first_factors def to_json(self) -> Dict[str, Any]: return { @@ -353,6 +351,7 @@ def to_json(self) -> Dict[str, Any]: "emailPassword": self.email_password.to_json(), "passwordless": self.passwordless.to_json(), "thirdParty": self.third_party.to_json(), + "firstFactors": self.first_factors, } diff --git a/supertokens_python/recipe/multitenancy/recipe.py b/supertokens_python/recipe/multitenancy/recipe.py index 9cc7aede8..8b0fee9e4 100644 --- a/supertokens_python/recipe/multitenancy/recipe.py +++ b/supertokens_python/recipe/multitenancy/recipe.py @@ -21,6 +21,7 @@ PrimitiveArrayClaim, ) from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.types import RecipeUserId from ...post_init_callbacks import PostSTInitCallbacks @@ -92,6 +93,8 @@ def __init__( ) RecipeModule.get_tenant_id = recipe_implementation.get_tenant_id + self.static_first_factors: Optional[List[str]] = None + self.all_available_first_factors: List[str] = [] def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, MultitenancyError) @@ -123,6 +126,8 @@ async def handle_api_request( self.config, self.recipe_implementation, self.static_third_party_providers, + self.all_available_first_factors, + self.static_first_factors, ) return await handle_login_methods_api( self.api_implementation, @@ -205,7 +210,11 @@ def __init__(self): default_max_age_in_sec = 60 * 60 async def fetch_value( - _: str, tenant_id: str, user_context: Dict[str, Any] + _user_id: str, + _recipe_user_id: RecipeUserId, + tenant_id: str, + _current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> Optional[List[str]]: recipe = MultitenancyRecipe.get_instance() diff --git a/supertokens_python/recipe/multitenancy/recipe_implementation.py b/supertokens_python/recipe/multitenancy/recipe_implementation.py index 30a1105b4..90578bacb 100644 --- a/supertokens_python/recipe/multitenancy/recipe_implementation.py +++ b/supertokens_python/recipe/multitenancy/recipe_implementation.py @@ -22,21 +22,18 @@ AssociateUserToTenantThirdPartyUserAlreadyExistsError, DisassociateUserFromTenantOkResult, ) +from supertokens_python.types import RecipeUserId from .interfaces import ( + AssociateUserToTenantNotAllowedError, RecipeInterface, TenantConfig, CreateOrUpdateTenantOkResult, DeleteTenantOkResult, - TenantConfigResponse, - GetTenantOkResult, - EmailPasswordConfig, - PasswordlessConfig, - ThirdPartyConfig, ListAllTenantsOkResult, CreateOrUpdateThirdPartyConfigOkResult, DeleteThirdPartyConfigOkResult, - ListAllTenantsItem, + TenantConfigCreateOrUpdate, ) if TYPE_CHECKING: @@ -48,7 +45,7 @@ from .constants import DEFAULT_TENANT_ID -def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: +def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfig: from supertokens_python.recipe.thirdparty.provider import ( UserInfoMap, UserFields, @@ -109,14 +106,12 @@ def parse_tenant_config(tenant: Dict[str, Any]) -> TenantConfigResponse: ) ) - return TenantConfigResponse( - emailpassword=EmailPasswordConfig(tenant["emailPassword"]["enabled"]), - passwordless=PasswordlessConfig(tenant["passwordless"]["enabled"]), - third_party=ThirdPartyConfig( - tenant["thirdParty"]["enabled"], - providers, - ), + return TenantConfig( + tenant_id=tenant["tenantId"], + third_party_providers=providers, core_config=tenant["coreConfig"], + first_factors=tenant.get("firstFactors"), + required_secondary_factors=tenant.get("requiredSecondaryFactors"), ) @@ -134,15 +129,24 @@ async def get_tenant_id( async def create_or_update_tenant( self, tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Dict[str, Any], ) -> CreateOrUpdateTenantOkResult: + json_body: Dict[str, Any] = { + "tenantId": tenant_id, + } + if config is not None: + if not config.is_first_factors_unchanged(): + json_body["firstFactors"] = config.get_first_factors_for_update() + if not config.is_required_secondary_factors_unchanged(): + json_body[ + "requiredSecondaryFactors" + ] = config.get_required_secondary_factors_for_update() + json_body["coreConfig"] = config.core_config + response = await self.querier.send_put_request( - NormalisedURLPath("/recipe/multitenancy/tenant"), - { - "tenantId": tenant_id, - **(config.to_json() if config is not None else {}), - }, + NormalisedURLPath("/recipe/multitenancy/tenant/v2"), + json_body, user_context=user_context, ) return CreateOrUpdateTenantOkResult( @@ -163,10 +167,10 @@ async def delete_tenant( async def get_tenant( self, tenant_id: Optional[str], user_context: Dict[str, Any] - ) -> Optional[GetTenantOkResult]: + ) -> Optional[TenantConfig]: res = await self.querier.send_get_request( NormalisedURLPath( - f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant" + f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/v2" ), None, user_context=user_context, @@ -177,34 +181,22 @@ async def get_tenant( tenant_config = parse_tenant_config(res) - return GetTenantOkResult( - emailpassword=tenant_config.emailpassword, - passwordless=tenant_config.passwordless, - third_party=tenant_config.third_party, - core_config=tenant_config.core_config, - ) + return tenant_config async def list_all_tenants( self, user_context: Dict[str, Any] ) -> ListAllTenantsOkResult: response = await self.querier.send_get_request( - NormalisedURLPath("/recipe/multitenancy/tenant/list"), + NormalisedURLPath("/recipe/multitenancy/tenant/list/v2"), {}, user_context=user_context, ) - tenant_items: List[ListAllTenantsItem] = [] + tenant_items: List[TenantConfig] = [] for tenant in response["tenants"]: config = parse_tenant_config(tenant) - item = ListAllTenantsItem( - tenant["tenantId"], - config.emailpassword, - config.passwordless, - config.third_party, - config.core_config, - ) - tenant_items.append(item) + tenant_items.append(config) return ListAllTenantsOkResult( tenants=tenant_items, @@ -253,20 +245,24 @@ async def delete_third_party_config( ) async def associate_user_to_tenant( - self, tenant_id: Optional[str], user_id: str, user_context: Dict[str, Any] + self, + tenant_id: Optional[str], + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], ) -> Union[ AssociateUserToTenantOkResult, AssociateUserToTenantUnknownUserIdError, AssociateUserToTenantEmailAlreadyExistsError, AssociateUserToTenantPhoneNumberAlreadyExistsError, AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantNotAllowedError, ]: response = await self.querier.send_post_request( NormalisedURLPath( f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/user" ), { - "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), }, user_context=user_context, ) @@ -289,18 +285,23 @@ async def associate_user_to_tenant( == AssociateUserToTenantThirdPartyUserAlreadyExistsError.status ): return AssociateUserToTenantThirdPartyUserAlreadyExistsError() + if response["status"] == AssociateUserToTenantNotAllowedError.status: + return AssociateUserToTenantNotAllowedError(response["reason"]) raise Exception("Should never come here") - async def dissociate_user_from_tenant( - self, tenant_id: Optional[str], user_id: str, user_context: Dict[str, Any] + async def disassociate_user_from_tenant( + self, + tenant_id: Optional[str], + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], ) -> DisassociateUserFromTenantOkResult: response = await self.querier.send_post_request( NormalisedURLPath( f"{tenant_id or DEFAULT_TENANT_ID}/recipe/multitenancy/tenant/user/remove" ), { - "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), }, user_context=user_context, ) diff --git a/supertokens_python/recipe/multitenancy/syncio/__init__.py b/supertokens_python/recipe/multitenancy/syncio/__init__.py index 27025f476..5448f2612 100644 --- a/supertokens_python/recipe/multitenancy/syncio/__init__.py +++ b/supertokens_python/recipe/multitenancy/syncio/__init__.py @@ -15,14 +15,16 @@ from typing import Any, Dict, Optional, TYPE_CHECKING from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.recipe.multitenancy.interfaces import TenantConfigCreateOrUpdate +from supertokens_python.types import RecipeUserId if TYPE_CHECKING: - from ..interfaces import TenantConfig, ProviderConfig + from ..interfaces import ProviderConfig def create_or_update_tenant( tenant_id: str, - config: Optional[TenantConfig], + config: Optional[TenantConfigCreateOrUpdate], user_context: Optional[Dict[str, Any]] = None, ): if user_context is None: @@ -95,7 +97,7 @@ def delete_third_party_config( def associate_user_to_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ): if user_context is None: @@ -103,19 +105,19 @@ def associate_user_to_tenant( from supertokens_python.recipe.multitenancy.asyncio import associate_user_to_tenant - return sync(associate_user_to_tenant(tenant_id, user_id, user_context)) + return sync(associate_user_to_tenant(tenant_id, recipe_user_id, user_context)) -def dissociate_user_from_tenant( +def disassociate_user_from_tenant( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, user_context: Optional[Dict[str, Any]] = None, ): if user_context is None: user_context = {} from supertokens_python.recipe.multitenancy.asyncio import ( - dissociate_user_from_tenant, + disassociate_user_from_tenant, ) - return sync(dissociate_user_from_tenant(tenant_id, user_id, user_context)) + return sync(disassociate_user_from_tenant(tenant_id, recipe_user_id, user_context)) diff --git a/supertokens_python/recipe/passwordless/api/consume_code.py b/supertokens_python/recipe/passwordless/api/consume_code.py index 442c6ba7b..333194200 100644 --- a/supertokens_python/recipe/passwordless/api/consume_code.py +++ b/supertokens_python/recipe/passwordless/api/consume_code.py @@ -12,9 +12,18 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.exceptions import raise_bad_input_exception -from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions -from supertokens_python.utils import send_200_response +from supertokens_python.recipe.passwordless.interfaces import ( + APIInterface, + APIOptions, + ConsumeCodePostOkResult, +) +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) async def consume_code( @@ -56,13 +65,44 @@ async def consume_code( pre_auth_session_id = body["preAuthSessionId"] + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.consume_code_post( pre_auth_session_id, user_input_code, device_id, link_code, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, ) + + if isinstance(result, ConsumeCodePostOkResult): + return send_200_response( + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=result.user, + session_container=result.session, + created_new_recipe_user=result.created_new_recipe_user, + user_context=user_context, + ), + }, + api_options.response, + ) + return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/passwordless/api/create_code.py b/supertokens_python/recipe/passwordless/api/create_code.py index c4c162f86..2c57519df 100644 --- a/supertokens_python/recipe/passwordless/api/create_code.py +++ b/supertokens_python/recipe/passwordless/api/create_code.py @@ -14,7 +14,8 @@ from typing import Union, Any, Dict import phonenumbers # type: ignore -from phonenumbers import format_number, parse # type: ignore +from phonenumbers import format_number, parse +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed # type: ignore from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions from supertokens_python.recipe.passwordless.utils import ( @@ -23,7 +24,10 @@ ContactPhoneOnlyConfig, ) from supertokens_python.types import GeneralErrorResponse -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) async def create_code( @@ -109,11 +113,26 @@ async def create_code( except Exception: phone_number = phone_number.strip() + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.create_code_post( email=email, phone_number=phone_number, + session=session, tenant_id=tenant_id, api_options=api_options, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/passwordless/api/implementation.py b/supertokens_python/recipe/passwordless/api/implementation.py index 0a6f77764..3aa530f26 100644 --- a/supertokens_python/recipe/passwordless/api/implementation.py +++ b/supertokens_python/recipe/passwordless/api/implementation.py @@ -11,14 +11,31 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import ( + OkResponse, + PostAuthChecksOkResponse, + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + check_auth_type_and_linking_status, + filter_out_invalid_first_factors_or_throw_if_all_are_invalid, + get_authenticating_user_and_add_to_current_tenant_if_required, + post_auth_checks, + pre_auth_checks, +) from supertokens_python.logger import log_debug_message +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.multifactorauth.types import FactorIds from supertokens_python.recipe.passwordless.interfaces import ( APIInterface, APIOptions, + CheckCodeOkResult, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeOkResult, ConsumeCodePostExpiredUserInputCodeError, ConsumeCodePostIncorrectUserInputCodeError, ConsumeCodePostOkResult, @@ -32,6 +49,7 @@ PhoneNumberExistsGetOkResult, ResendCodePostOkResult, ResendCodePostRestartFlowError, + SignInUpPostNotAllowedResponse, ) from supertokens_python.recipe.passwordless.types import ( PasswordlessLoginSMSTemplateVars, @@ -40,33 +58,209 @@ ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, + get_enabled_pwless_factors, +) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.types import ( + AccountInfo, + User, + GeneralErrorResponse, + LoginMethod, + RecipeUserId, ) -from supertokens_python.recipe.session.asyncio import create_new_session -from supertokens_python.types import GeneralErrorResponse from ...emailverification import EmailVerificationRecipe from ...emailverification.interfaces import CreateEmailVerificationTokenOkResult +class PasswordlessUserResult: + user: User + login_method: Union[LoginMethod, None] + + def __init__(self, user: User, login_method: Union[LoginMethod, None]): + self.user = user + self.login_method = login_method + + +async def get_passwordless_user_by_account_info( + tenant_id: str, + user_context: Dict[str, Any], + account_info: AccountInfo, +) -> Optional[PasswordlessUserResult]: + existing_users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=account_info, + do_union_of_account_info=False, + user_context=user_context, + ) + log_debug_message( + f"get_passwordless_user_by_account_info got {len(existing_users)} from core resp {account_info}" + ) + + users_with_matching_login_methods = [ + PasswordlessUserResult( + user=user, + login_method=next( + ( + lm + for lm in user.login_methods + if lm.recipe_id == "passwordless" + and ( + lm.has_same_email_as(account_info.email) + or lm.has_same_phone_number_as(account_info.phone_number) + ) + ), + None, + ), + ) + for user in existing_users + ] + users_with_matching_login_methods = [ + user_data + for user_data in users_with_matching_login_methods + if user_data.login_method is not None + ] + + log_debug_message( + f"get_passwordless_user_by_account_info {len(users_with_matching_login_methods)} has matching login methods" + ) + + if len(users_with_matching_login_methods) > 1: + raise Exception( + "This should never happen: multiple users exist matching the accountInfo in passwordless createCode" + ) + + if len(users_with_matching_login_methods) == 0: + return None + + return users_with_matching_login_methods[0] + + class APIImplementation(APIInterface): async def create_code_post( self, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[CreateCodePostOkResult, GeneralErrorResponse]: + ) -> Union[ + CreateCodePostOkResult, SignInUpPostNotAllowedResponse, GeneralErrorResponse + ]: + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_002)", + "LINKING_TO_SESSION_USER_FAILED": { + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_019)", + }, + } + + account_info = AccountInfo( + email=email, + phone_number=phone_number, + ) + + user_with_matching_login_method = await get_passwordless_user_by_account_info( + tenant_id, user_context, account_info + ) + + factor_ids = [] + if session is not None: + factor_ids = [ + FactorIds.OTP_EMAIL if email is not None else FactorIds.OTP_PHONE + ] + else: + factor_ids = get_enabled_pwless_factors(api_options.config) + if email is not None: + factor_ids = [ + f + for f in factor_ids + if f in [FactorIds.OTP_EMAIL, FactorIds.LINK_EMAIL] + ] + else: + factor_ids = [ + f + for f in factor_ids + if f in [FactorIds.OTP_PHONE, FactorIds.LINK_PHONE] + ] + + is_verified_input = True + if user_with_matching_login_method is not None: + assert user_with_matching_login_method.login_method is not None + is_verified_input = user_with_matching_login_method.login_method.verified + + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id="passwordless", + email=account_info.email, + phone_number=account_info.phone_number, + ), + is_sign_up=user_with_matching_login_method is None, + authenticating_user=( + user_with_matching_login_method.user + if user_with_matching_login_method + else None + ), + is_verified=is_verified_input, + sign_in_verifies_login_method=True, + skip_session_user_update_in_core=True, + tenant_id=tenant_id, + factor_ids=factor_ids, + user_context=user_context, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if not isinstance(pre_auth_checks_result, OkResponse): + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + raise Exception("Should never come here") + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInUpPostNotAllowedResponse(reason=reason) + user_input_code = None if api_options.config.get_custom_user_input_code is not None: user_input_code = await api_options.config.get_custom_user_input_code( tenant_id, user_context ) + + user_input_code_input = user_input_code + if api_options.config.get_custom_user_input_code is not None: + user_input_code_input = await api_options.config.get_custom_user_input_code( + tenant_id, user_context + ) response = await api_options.recipe_implementation.create_code( - email, phone_number, user_input_code, tenant_id, user_context + email=account_info.email, + phone_number=account_info.phone_number, + user_input_code=user_input_code_input, + tenant_id=tenant_id, + user_context=user_context, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) + magic_link = None user_input_code = None flow_type = api_options.config.flow_type + + if all( + _id.startswith("link") for _id in pre_auth_checks_result.valid_factor_ids + ): + flow_type = "MAGIC_LINK" + elif all( + _id.startswith("otp") for _id in pre_auth_checks_result.valid_factor_ids + ): + flow_type = "USER_INPUT_CODE" + else: + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + if flow_type in ("MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"): magic_link = ( api_options.app_info.get_origin( @@ -99,6 +293,7 @@ async def create_code_post( code_life_time=response.code_life_time, pre_auth_session_id=response.pre_auth_session_id, tenant_id=tenant_id, + is_first_factor=pre_auth_checks_result.is_first_factor, ) await api_options.email_delivery.ingredient_interface_impl.send_email( passwordless_email_delivery_input, user_context @@ -117,6 +312,7 @@ async def create_code_post( code_life_time=response.code_life_time, pre_auth_session_id=response.pre_auth_session_id, tenant_id=tenant_id, + is_first_factor=pre_auth_checks_result.is_first_factor, ) await api_options.sms_delivery.ingredient_interface_impl.send_sms( sms_input, user_context @@ -130,6 +326,8 @@ async def resend_code_post( self, device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -139,16 +337,48 @@ async def resend_code_post( device_info = await api_options.recipe_implementation.list_codes_by_device_id( device_id=device_id, tenant_id=tenant_id, user_context=user_context ) + if device_info is None: return ResendCodePostRestartFlowError() + if ( - isinstance(api_options.config.contact_config, ContactEmailOnlyConfig) - and device_info.email is None - ) or ( - isinstance(api_options.config.contact_config, ContactPhoneOnlyConfig) + api_options.config.contact_config.contact_method == "PHONE" and device_info.phone_number is None + ) or ( + api_options.config.contact_config.contact_method == "EMAIL" + and device_info.email is None ): return ResendCodePostRestartFlowError() + + user_with_matching_login_method = await get_passwordless_user_by_account_info( + tenant_id=tenant_id, + user_context=user_context, + account_info=AccountInfo( + email=device_info.email, + phone_number=device_info.phone_number, + ), + ) + + auth_type_info = await check_auth_type_and_linking_status( + session=session, + account_info=AccountInfoWithRecipeId( + recipe_id="passwordless", + email=device_info.email, + phone_number=device_info.phone_number, + ), + input_user=( + user_with_matching_login_method.user + if user_with_matching_login_method + else None + ), + skip_session_user_update_in_core=True, + user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if auth_type_info.status == "LINKING_TO_SESSION_USER_FAILED": + return ResendCodePostRestartFlowError() + number_of_tries_to_create_new_code = 0 while True: number_of_tries_to_create_new_code += 1 @@ -157,14 +387,22 @@ async def resend_code_post( user_input_code = await api_options.config.get_custom_user_input_code( tenant_id, user_context ) + user_input_code_input = user_input_code + if api_options.config.get_custom_user_input_code is not None: + user_input_code_input = ( + await api_options.config.get_custom_user_input_code( + tenant_id, user_context + ) + ) response = ( await api_options.recipe_implementation.create_new_code_for_device( device_id=device_id, - user_input_code=user_input_code, + user_input_code=user_input_code_input, tenant_id=tenant_id, user_context=user_context, ) ) + if isinstance( response, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError ): @@ -177,7 +415,30 @@ async def resend_code_post( if isinstance(response, CreateNewCodeForDeviceOkResult): magic_link = None user_input_code = None + + factor_ids = [] + if session is not None: + factor_ids = [ + ( + FactorIds.OTP_EMAIL + if device_info.email is not None + else FactorIds.OTP_PHONE + ) + ] + else: + factor_ids = get_enabled_pwless_factors(api_options.config) + factor_ids = await filter_out_invalid_first_factors_or_throw_if_all_are_invalid( + factor_ids, tenant_id, False, user_context + ) + flow_type = api_options.config.flow_type + if all(id.startswith("link") for id in factor_ids): + flow_type = "MAGIC_LINK" + elif all(id.startswith("otp") for id in factor_ids): + flow_type = "USER_INPUT_CODE" + else: + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + if flow_type in ("MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"): magic_link = ( api_options.app_info.get_origin( @@ -195,20 +456,31 @@ async def resend_code_post( if flow_type in ("USER_INPUT_CODE", "USER_INPUT_CODE_AND_MAGIC_LINK"): user_input_code = response.user_input_code - if isinstance( - api_options.config.contact_config, ContactEmailOnlyConfig - ) or ( - isinstance( - api_options.config.contact_config, ContactEmailOrPhoneConfig - ) - and device_info.email is not None + if api_options.config.contact_config.contact_method == "PHONE" or ( + api_options.config.contact_config.contact_method == "EMAIL_OR_PHONE" + and device_info.phone_number is not None ): - if device_info.email is None: - raise Exception("Should never come here") - + log_debug_message( + "Sending passwordless login SMS to %s", device_info.phone_number + ) + assert device_info.phone_number is not None + sms_input = PasswordlessLoginSMSTemplateVars( + phone_number=device_info.phone_number, + user_input_code=user_input_code, + url_with_link_code=magic_link, + code_life_time=response.code_life_time, + pre_auth_session_id=response.pre_auth_session_id, + tenant_id=tenant_id, + is_first_factor=auth_type_info.is_first_factor, + ) + await api_options.sms_delivery.ingredient_interface_impl.send_sms( + sms_input, user_context + ) + else: log_debug_message( "Sending passwordless login email to %s", device_info.email ) + assert device_info.email is not None passwordless_email_delivery_input = ( PasswordlessLoginEmailTemplateVars( email=device_info.email, @@ -217,32 +489,15 @@ async def resend_code_post( code_life_time=response.code_life_time, pre_auth_session_id=response.pre_auth_session_id, tenant_id=tenant_id, + is_first_factor=auth_type_info.is_first_factor, ) ) await api_options.email_delivery.ingredient_interface_impl.send_email( passwordless_email_delivery_input, user_context ) - elif isinstance( - api_options.config.contact_config, - (ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig), - ): - if device_info.phone_number is None: - raise Exception("Should never come here") - log_debug_message( - "Sending passwordless login SMS to %s", device_info.phone_number - ) - sms_input = PasswordlessLoginSMSTemplateVars( - phone_number=device_info.phone_number, - user_input_code=user_input_code, - url_with_link_code=magic_link, - code_life_time=response.code_life_time, - pre_auth_session_id=response.pre_auth_session_id, - tenant_id=tenant_id, - ) - await api_options.sms_delivery.ingredient_interface_impl.send_sms( - sms_input, user_context - ) + return ResendCodePostOkResult() + return ResendCodePostRestartFlowError() async def consume_code_post( @@ -251,65 +506,234 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ ConsumeCodePostOkResult, ConsumeCodePostRestartFlowError, + GeneralErrorResponse, ConsumeCodePostIncorrectUserInputCodeError, ConsumeCodePostExpiredUserInputCodeError, - GeneralErrorResponse, + SignInUpPostNotAllowedResponse, ]: + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_002)", + "SIGN_IN_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_003)", + "LINKING_TO_SESSION_USER_FAILED": { + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_017)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_018)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_019)", + }, + } + + device_info = ( + await api_options.recipe_implementation.list_codes_by_pre_auth_session_id( + tenant_id=tenant_id, + pre_auth_session_id=pre_auth_session_id, + user_context=user_context, + ) + ) + + if not device_info: + return ConsumeCodePostRestartFlowError() + + recipe_id = "passwordless" + account_info = AccountInfo( + phone_number=device_info.phone_number, email=device_info.email + ) + + async def check_credentials(_: str): + nonlocal check_credentials_response + if check_credentials_response is None: + check_credentials_response = ( + await api_options.recipe_implementation.check_code( + pre_auth_session_id=pre_auth_session_id, + device_id=device_id, + user_input_code=user_input_code, + link_code=link_code, + tenant_id=tenant_id, + user_context=user_context, + ) + ) + return isinstance(check_credentials_response, CheckCodeOkResult) + + check_credentials_response = None + authenticating_user = ( + await get_authenticating_user_and_add_to_current_tenant_if_required( + email=account_info.email, + phone_number=account_info.phone_number, + third_party=None, + recipe_id=recipe_id, + user_context=user_context, + session=session, + tenant_id=tenant_id, + check_credentials_on_tenant=check_credentials, + ) + ) + + ev_instance = EmailVerificationRecipe.get_instance_optional() + if account_info.email and session and ev_instance: + session_user = await get_user(session.get_user_id(), user_context) + if session_user is None: + raise UnauthorisedError( + "Session user not found", + ) + + login_method = next( + ( + lm + for lm in session_user.login_methods + if lm.recipe_user_id.get_as_string() + == session.get_recipe_user_id().get_as_string() + ), + None, + ) + if login_method is None: + raise UnauthorisedError( + "Session user and session recipeUserId is inconsistent", + ) + + if ( + login_method.has_same_email_as(account_info.email) + and not login_method.verified + ): + if await check_credentials(tenant_id): + token_response = await ev_instance.recipe_implementation.create_email_verification_token( + tenant_id=tenant_id, + recipe_user_id=login_method.recipe_user_id, + email=account_info.email, + user_context=user_context, + ) + if isinstance(token_response, CreateEmailVerificationTokenOkResult): + await ev_instance.recipe_implementation.verify_email_using_token( + tenant_id=tenant_id, + token=token_response.token, + attempt_account_linking=False, + user_context=user_context, + ) + + factor_id = ( + FactorIds.OTP_EMAIL + if device_info.email and user_input_code + else ( + FactorIds.LINK_EMAIL + if device_info.email + else (FactorIds.OTP_PHONE if user_input_code else FactorIds.LINK_PHONE) + ) + ) + + is_sign_up = authenticating_user is None + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id="passwordless", + email=device_info.email, + phone_number=device_info.phone_number, + ), + factor_ids=[factor_id], + authenticating_user=( + authenticating_user.user if authenticating_user else None + ), + is_sign_up=is_sign_up, + is_verified=( + authenticating_user.login_method.verified + if authenticating_user and authenticating_user.login_method + else True + ), + sign_in_verifies_login_method=True, + skip_session_user_update_in_core=False, + tenant_id=tenant_id, + user_context=user_context, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if not isinstance(pre_auth_checks_result, OkResponse): + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInUpPostNotAllowedResponse(reason=reason) + + if check_credentials_response is not None: + if not isinstance(check_credentials_response, CheckCodeOkResult): + return check_credentials_response + response = await api_options.recipe_implementation.consume_code( pre_auth_session_id=pre_auth_session_id, - user_input_code=user_input_code, device_id=device_id, + user_input_code=user_input_code, link_code=link_code, + session=session, tenant_id=tenant_id, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) - if isinstance(response, ConsumeCodeExpiredUserInputCodeError): - return ConsumeCodePostExpiredUserInputCodeError( - failed_code_input_attempt_count=response.failed_code_input_attempt_count, - maximum_code_input_attempts=response.maximum_code_input_attempts, - ) + if isinstance(response, ConsumeCodeRestartFlowError): + return ConsumeCodePostRestartFlowError() if isinstance(response, ConsumeCodeIncorrectUserInputCodeError): return ConsumeCodePostIncorrectUserInputCodeError( - failed_code_input_attempt_count=response.failed_code_input_attempt_count, - maximum_code_input_attempts=response.maximum_code_input_attempts, + response.failed_code_input_attempt_count, + response.maximum_code_input_attempts, ) - if isinstance(response, ConsumeCodeRestartFlowError): - return ConsumeCodePostRestartFlowError() - - user = response.user - - if user.email is not None: - ev_instance = EmailVerificationRecipe.get_instance_optional() - if ev_instance is not None: - token_response = await ev_instance.recipe_implementation.create_email_verification_token( - user.user_id, user.email, tenant_id, user_context - ) + if isinstance(response, ConsumeCodeExpiredUserInputCodeError): + return ConsumeCodePostExpiredUserInputCodeError( + response.failed_code_input_attempt_count, + response.maximum_code_input_attempts, + ) + if not isinstance(response, ConsumeCodeOkResult): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[response.reason] + return SignInUpPostNotAllowedResponse(reason=reason) - if isinstance(token_response, CreateEmailVerificationTokenOkResult): - await ev_instance.recipe_implementation.verify_email_using_token( - token_response.token, tenant_id, user_context - ) + authenticating_user_input: User + if response.user: + authenticating_user_input = response.user + elif authenticating_user: + authenticating_user_input = authenticating_user.user + else: + raise Exception("Should never come here") + recipe_user_id_input: RecipeUserId + if response.recipe_user_id: + recipe_user_id_input = response.recipe_user_id + elif authenticating_user: + assert authenticating_user.login_method is not None + recipe_user_id_input = authenticating_user.login_method.recipe_user_id + else: + raise Exception("Should never come here") - session = await create_new_session( - request=api_options.request, + post_auth_checks_result = await post_auth_checks( + factor_id=factor_id, + is_sign_up=is_sign_up, + authenticated_user=authenticating_user_input, + recipe_user_id=recipe_user_id_input, tenant_id=tenant_id, - user_id=user.user_id, - access_token_payload={}, - session_data_in_database={}, user_context=user_context, + session=session, + request=api_options.request, ) + if not isinstance(post_auth_checks_result, PostAuthChecksOkResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpPostNotAllowedResponse(reason) + return ConsumeCodePostOkResult( - created_new_user=response.created_new_user, - user=response.user, - session=session, + created_new_recipe_user=response.created_new_recipe_user, + user=post_auth_checks_result.user, + session=post_auth_checks_result.session, ) async def email_exists_get( @@ -319,10 +743,21 @@ async def email_exists_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[EmailExistsGetOkResult, GeneralErrorResponse]: - response = await api_options.recipe_implementation.get_user_by_email( - email, tenant_id, user_context + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(email=email), + do_union_of_account_info=False, + user_context=user_context, ) - return EmailExistsGetOkResult(exists=response is not None) + user_exists = any( + any( + lm.recipe_id == "passwordless" and lm.has_same_email_as(email) + for lm in u.login_methods + ) + for u in users + ) + + return EmailExistsGetOkResult(exists=user_exists) async def phone_number_exists_get( self, @@ -331,7 +766,10 @@ async def phone_number_exists_get( api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[PhoneNumberExistsGetOkResult, GeneralErrorResponse]: - response = await api_options.recipe_implementation.get_user_by_phone_number( - phone_number, tenant_id, user_context + users = await AccountLinkingRecipe.get_instance().recipe_implementation.list_users_by_account_info( + tenant_id=tenant_id, + account_info=AccountInfo(phone_number=phone_number), + do_union_of_account_info=False, + user_context=user_context, ) - return PhoneNumberExistsGetOkResult(exists=response is not None) + return PhoneNumberExistsGetOkResult(exists=len(users) > 0) diff --git a/supertokens_python/recipe/passwordless/api/resend_code.py b/supertokens_python/recipe/passwordless/api/resend_code.py index 8afd885db..8ee4cb762 100644 --- a/supertokens_python/recipe/passwordless/api/resend_code.py +++ b/supertokens_python/recipe/passwordless/api/resend_code.py @@ -12,9 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. from typing import Any, Dict +from supertokens_python.auth_utils import load_session_in_auth_api_if_needed from supertokens_python.exceptions import raise_bad_input_exception from supertokens_python.recipe.passwordless.interfaces import APIInterface, APIOptions -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) async def resend_code( @@ -39,7 +43,26 @@ async def resend_code( pre_auth_session_id = body["preAuthSessionId"] device_id = body["deviceId"] + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/passwordless/asyncio/__init__.py b/supertokens_python/recipe/passwordless/asyncio/__init__.py index 4632f6109..8f29be622 100644 --- a/supertokens_python/recipe/passwordless/asyncio/__init__.py +++ b/supertokens_python/recipe/passwordless/asyncio/__init__.py @@ -11,10 +11,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from supertokens_python import get_request_from_user_context +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.passwordless.interfaces import ( + CheckCodeExpiredUserInputCodeError, + CheckCodeIncorrectUserInputCodeError, + CheckCodeOkResult, + CheckCodeRestartFlowError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeOkResult, @@ -23,8 +28,8 @@ CreateNewCodeForDeviceOkResult, CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, - DeleteUserInfoOkResult, - DeleteUserInfoUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, RevokeAllCodesOkResult, RevokeCodeOkResult, UpdateUserEmailAlreadyExistsError, @@ -37,8 +42,9 @@ DeviceType, EmailTemplateVars, SMSTemplateVars, - User, ) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import RecipeUserId async def create_code( @@ -46,6 +52,7 @@ async def create_code( email: Union[None, str] = None, phone_number: Union[None, str] = None, user_input_code: Union[None, str] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> CreateCodeOkResult: if user_context is None: @@ -55,7 +62,9 @@ async def create_code( phone_number=phone_number, user_input_code=user_input_code, tenant_id=tenant_id, + session=session, user_context=user_context, + should_try_linking_with_session_user=session is not None, ) @@ -85,12 +94,14 @@ async def consume_code( user_input_code: Union[str, None] = None, device_id: Union[str, None] = None, link_code: Union[str, None] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ ConsumeCodeOkResult, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, ]: if user_context is None: user_context = {} @@ -100,48 +111,14 @@ async def consume_code( device_id=device_id, link_code=link_code, tenant_id=tenant_id, - user_context=user_context, - ) - - -async def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return await PasswordlessRecipe.get_instance().recipe_implementation.get_user_by_id( - user_id=user_id, user_context=user_context - ) - - -async def get_user_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return ( - await PasswordlessRecipe.get_instance().recipe_implementation.get_user_by_email( - email=email, - tenant_id=tenant_id, - user_context=user_context, - ) - ) - - -async def get_user_by_phone_number( - tenant_id: str, phone_number: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return await PasswordlessRecipe.get_instance().recipe_implementation.get_user_by_phone_number( - phone_number=phone_number, - tenant_id=tenant_id, + session=session, + should_try_linking_with_session_user=session is not None, user_context=user_context, ) async def update_user( - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None] = None, phone_number: Union[str, None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -150,11 +127,13 @@ async def update_user( UpdateUserUnknownUserIdError, UpdateUserEmailAlreadyExistsError, UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ]: if user_context is None: user_context = {} return await PasswordlessRecipe.get_instance().recipe_implementation.update_user( - user_id=user_id, + recipe_user_id=recipe_user_id, email=email, phone_number=phone_number, user_context=user_context, @@ -162,22 +141,51 @@ async def update_user( async def delete_email_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: if user_context is None: user_context = {} return await PasswordlessRecipe.get_instance().recipe_implementation.delete_email_for_user( - user_id=user_id, user_context=user_context + recipe_user_id=recipe_user_id, + user_context=user_context, ) async def delete_phone_number_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: if user_context is None: user_context = {} return await PasswordlessRecipe.get_instance().recipe_implementation.delete_phone_number_for_user( - user_id=user_id, user_context=user_context + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + +async def check_code( + tenant_id: str, + pre_auth_session_id: str, + user_input_code: Union[str, None] = None, + device_id: Union[str, None] = None, + link_code: Union[str, None] = None, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, +]: + if user_context is None: + user_context = {} + return await PasswordlessRecipe.get_instance().recipe_implementation.check_code( + pre_auth_session_id=pre_auth_session_id, + user_input_code=user_input_code, + device_id=device_id, + link_code=link_code, + tenant_id=tenant_id, + user_context=user_context, ) @@ -281,6 +289,7 @@ async def signinup( tenant_id: str, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> ConsumeCodeOkResult: if user_context is None: @@ -290,6 +299,7 @@ async def signinup( phone_number=phone_number, tenant_id=tenant_id, user_context=user_context, + session=session, ) diff --git a/supertokens_python/recipe/passwordless/interfaces.py b/supertokens_python/recipe/passwordless/interfaces.py index 838890d7e..877539069 100644 --- a/supertokens_python/recipe/passwordless/interfaces.py +++ b/supertokens_python/recipe/passwordless/interfaces.py @@ -14,14 +14,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from typing_extensions import Literal +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.recipe.session import SessionContainer -from supertokens_python.types import APIResponse, GeneralErrorResponse +from supertokens_python.types import ( + APIResponse, + User, + GeneralErrorResponse, + RecipeUserId, +) from ...supertokens import AppInfo @@ -31,7 +37,6 @@ PasswordlessLoginEmailTemplateVars, PasswordlessLoginSMSTemplateVars, SMSDeliveryIngredient, - User, ) from .utils import PasswordlessConfig @@ -84,10 +89,49 @@ class CreateNewCodeForDeviceUserInputCodeAlreadyUsedError: pass +class ConsumedDevice: + def __init__( + self, + pre_auth_session_id: str, + failed_code_input_attempt_count: int, + email: Optional[str] = None, + phone_number: Optional[str] = None, + ): + self.pre_auth_session_id = pre_auth_session_id + self.failed_code_input_attempt_count = failed_code_input_attempt_count + self.email = email + self.phone_number = phone_number + + @staticmethod + def from_json(json: Dict[str, Any]) -> ConsumedDevice: + return ConsumedDevice( + pre_auth_session_id=json["preAuthSessionId"], + failed_code_input_attempt_count=json["failedCodeInputAttemptCount"], + email=json["email"] if "email" in json else None, + phone_number=json["phoneNumber"] if "phoneNumber" in json else None, + ) + + def to_json(self) -> Dict[str, Any]: + return { + "preAuthSessionId": self.pre_auth_session_id, + "failedCodeInputAttemptCount": self.failed_code_input_attempt_count, + "email": self.email, + "phoneNumber": self.phone_number, + } + + class ConsumeCodeOkResult: - def __init__(self, created_new_user: bool, user: User): - self.created_new_user = created_new_user + def __init__( + self, + created_new_recipe_user: bool, + user: User, + recipe_user_id: RecipeUserId, + consumed_device: ConsumedDevice, + ): + self.created_new_recipe_user = created_new_recipe_user self.user = user + self.recipe_user_id = recipe_user_id + self.consumed_device = consumed_device class ConsumeCodeIncorrectUserInputCodeError: @@ -126,22 +170,42 @@ class UpdateUserPhoneNumberAlreadyExistsError: pass -class DeleteUserInfoOkResult: +class RevokeAllCodesOkResult: pass -class DeleteUserInfoUnknownUserIdError: +class RevokeCodeOkResult: pass -class RevokeAllCodesOkResult: +class CheckCodeOkResult: + def __init__(self, consumed_device: ConsumedDevice): + self.status = "OK" + self.consumed_device = consumed_device + + +class CheckCodeIncorrectUserInputCodeError(ConsumeCodeIncorrectUserInputCodeError): pass -class RevokeCodeOkResult: +class CheckCodeExpiredUserInputCodeError(ConsumeCodeExpiredUserInputCodeError): pass +class CheckCodeRestartFlowError(ConsumeCodeRestartFlowError): + pass + + +class EmailChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason + + +class PhoneNumberChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason + + class RecipeInterface(ABC): def __init__(self): pass @@ -152,6 +216,8 @@ async def create_code( email: Union[None, str], phone_number: Union[None, str], user_input_code: Union[None, str], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> CreateCodeOkResult: @@ -178,6 +244,8 @@ async def consume_code( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> Union[ @@ -185,31 +253,31 @@ async def consume_code( ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, ]: pass @abstractmethod - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - pass - - @abstractmethod - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - pass - - @abstractmethod - async def get_user_by_phone_number( - self, phone_number: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: + async def check_code( + self, + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, + ]: pass @abstractmethod async def update_user( self, - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None], phone_number: Union[str, None], user_context: Dict[str, Any], @@ -218,19 +286,21 @@ async def update_user( UpdateUserUnknownUserIdError, UpdateUserEmailAlreadyExistsError, UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ]: pass @abstractmethod async def delete_email_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: pass @abstractmethod async def delete_phone_number_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: pass @abstractmethod @@ -337,21 +407,21 @@ def to_json(self): class ConsumeCodePostOkResult(APIResponse): status: str = "OK" - def __init__(self, created_new_user: bool, user: User, session: SessionContainer): - self.created_new_user = created_new_user + def __init__( + self, + created_new_recipe_user: bool, + user: User, + session: SessionContainer, + ): + self.created_new_recipe_user = created_new_recipe_user self.user = user self.session = session def to_json(self): - user = {"id": self.user.user_id, "time_joined": self.user.time_joined} - if self.user.email is not None: - user = {**user, "email": self.user.email} - if self.user.phone_number is not None: - user = {**user, "phoneNumber": self.user.phone_number} return { "status": self.status, - "createdNewUser": self.created_new_user, - "user": user, + "user": self.user.to_json(), + "createdNewRecipeUser": self.created_new_recipe_user, } @@ -416,6 +486,16 @@ def to_json(self): return {"status": self.status, "exists": self.exists} +class SignInUpPostNotAllowedResponse(APIResponse): + status: str = "SIGN_IN_UP_NOT_ALLOWED" + + def __init__(self, reason: str): + self.reason = reason + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "reason": self.reason} + + class APIInterface: def __init__(self): self.disable_create_code_post = False @@ -429,10 +509,14 @@ async def create_code_post( self, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], - ) -> Union[CreateCodePostOkResult, GeneralErrorResponse]: + ) -> Union[ + CreateCodePostOkResult, SignInUpPostNotAllowedResponse, GeneralErrorResponse + ]: pass @abstractmethod @@ -440,6 +524,8 @@ async def resend_code_post( self, device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -455,6 +541,8 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], @@ -464,6 +552,7 @@ async def consume_code_post( GeneralErrorResponse, ConsumeCodePostIncorrectUserInputCodeError, ConsumeCodePostExpiredUserInputCodeError, + SignInUpPostNotAllowedResponse, ]: pass diff --git a/supertokens_python/recipe/passwordless/recipe.py b/supertokens_python/recipe/passwordless/recipe.py index 0367710f3..ae0e752e7 100644 --- a/supertokens_python/recipe/passwordless/recipe.py +++ b/supertokens_python/recipe/passwordless/recipe.py @@ -15,17 +15,35 @@ from os import environ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Union, Optional +from supertokens_python.auth_utils import is_fake_email from supertokens_python.ingredients.emaildelivery import EmailDeliveryIngredient from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig from supertokens_python.ingredients.smsdelivery import SMSDeliveryIngredient from supertokens_python.querier import Querier +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import ( + FactorIds, + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetEmailsForFactorFromOtherRecipesFunc, + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + GetFactorsSetupForUserFromOtherRecipesFunc, + GetPhoneNumbersForFactorsFromOtherRecipesFunc, + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.recipe.passwordless.types import ( PasswordlessIngredients, PasswordlessLoginSMSTemplateVars, ) from typing_extensions import Literal +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import User, RecipeUserId + from .api import ( consume_code, create_code, @@ -54,14 +72,9 @@ from .utils import ( ContactConfig, OverrideConfig, + get_enabled_pwless_factors, validate_and_normalise_user_input, ) -from ..emailverification import EmailVerificationRecipe -from ..emailverification.interfaces import ( - GetEmailForUserIdOkResult, - EmailDoesNotExistError, - UnknownUserIdError, -) from ...post_init_callbacks import PostSTInitCallbacks if TYPE_CHECKING: @@ -142,9 +155,244 @@ def __init__( ) def callback(): - ev_recipe = EmailVerificationRecipe.get_instance_optional() - if ev_recipe: - ev_recipe.add_get_email_for_user_id_func(self.get_email_for_user_id) + mfa_instance = MultiFactorAuthRecipe.get_instance() + all_factors = get_enabled_pwless_factors(self.config) + if mfa_instance is not None: + + async def f1(_: TenantConfig): + return all_factors + + mfa_instance.add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc(f1) + ) + + async def get_factors_setup_for_user( + user: User, _: Dict[str, Any] + ) -> List[str]: + def is_factor_setup_for_user(user: User, factor_id: str) -> bool: + for login_method in user.login_methods: + if login_method.recipe_id != "passwordless": + continue + + if login_method.email is not None and not is_fake_email( + login_method.email + ): + if factor_id in [ + FactorIds.OTP_EMAIL, + FactorIds.LINK_EMAIL, + ]: + return True + + if login_method.phone_number is not None: + if factor_id in [ + FactorIds.OTP_PHONE, + FactorIds.LINK_PHONE, + ]: + return True + return False + + return [ + factor_id + for factor_id in all_factors + if is_factor_setup_for_user(user, factor_id) + ] + + mfa_instance.add_func_to_get_factors_setup_for_user_from_other_recipes( + GetFactorsSetupForUserFromOtherRecipesFunc( + get_factors_setup_for_user + ) + ) + + async def get_emails_for_factor( + user: User, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetEmailsForFactorOkResult, + GetEmailsForFactorUnknownSessionRecipeUserIdResult, + ]: + session_login_method = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ), + None, + ) + if session_login_method is None: + return GetEmailsForFactorUnknownSessionRecipeUserIdResult() + + ordered_login_methods = sorted( + user.login_methods, key=lambda lm: lm.time_joined + ) + + # MAIN LOGIC FOR THE FUNCTION STARTS HERE + non_fake_emails_passwordless = [ + lm.email + for lm in ordered_login_methods + if lm.recipe_id == "passwordless" + and lm.email is not None + and not is_fake_email(lm.email) + ] + + if not non_fake_emails_passwordless: + # This factor is not set up for email-based factors. + # We check for emails from other loginMethods and return those. + emails_result = [] + if ( + session_login_method.email is not None + and not is_fake_email(session_login_method.email) + ): + emails_result = [session_login_method.email] + + emails_result.extend( + [ + lm.email + for lm in ordered_login_methods + if lm.email is not None + and not is_fake_email(lm.email) + and lm.email not in emails_result + ] + ) + factor_id_to_emails_map = {} + if FactorIds.OTP_EMAIL in all_factors: + factor_id_to_emails_map[FactorIds.OTP_EMAIL] = emails_result + if FactorIds.LINK_EMAIL in all_factors: + factor_id_to_emails_map[ + FactorIds.LINK_EMAIL + ] = emails_result + return GetEmailsForFactorOkResult( + factor_id_to_emails_map=factor_id_to_emails_map + ) + elif len(non_fake_emails_passwordless) == 1: + # Return just this email to avoid creating more loginMethods + factor_id_to_emails_map = {} + if FactorIds.OTP_EMAIL in all_factors: + factor_id_to_emails_map[ + FactorIds.OTP_EMAIL + ] = non_fake_emails_passwordless + if FactorIds.LINK_EMAIL in all_factors: + factor_id_to_emails_map[ + FactorIds.LINK_EMAIL + ] = non_fake_emails_passwordless + return GetEmailsForFactorOkResult( + factor_id_to_emails_map=factor_id_to_emails_map + ) + + # Return all emails with passwordless login method, prioritizing session's email + emails_result = [] + if ( + session_login_method.email is not None + and session_login_method.email in non_fake_emails_passwordless + ): + emails_result = [session_login_method.email] + + emails_result.extend( + [ + email + for email in non_fake_emails_passwordless + if email not in emails_result + ] + ) + + factor_id_to_emails_map = {} + if FactorIds.OTP_EMAIL in all_factors: + factor_id_to_emails_map[FactorIds.OTP_EMAIL] = emails_result + if FactorIds.LINK_EMAIL in all_factors: + factor_id_to_emails_map[FactorIds.LINK_EMAIL] = emails_result + + return GetEmailsForFactorOkResult( + factor_id_to_emails_map=factor_id_to_emails_map + ) + + mfa_instance.add_func_to_get_emails_for_factor_from_other_recipes( + GetEmailsForFactorFromOtherRecipesFunc(get_emails_for_factor) + ) + + async def get_phone_numbers_for_factors( + user: User, session_recipe_user_id: RecipeUserId + ) -> Union[ + GetPhoneNumbersForFactorsOkResult, + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult, + ]: + session_login_method = next( + ( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() + == session_recipe_user_id.get_as_string() + ), + None, + ) + if session_login_method is None: + return ( + GetPhoneNumbersForFactorsUnknownSessionRecipeUserIdResult() + ) + + ordered_login_methods = sorted( + user.login_methods, key=lambda lm: lm.time_joined + ) + + phone_numbers = [ + lm.phone_number + for lm in ordered_login_methods + if lm.recipe_id == "passwordless" + and lm.phone_number is not None + ] + + if not phone_numbers: + phones_result = [] + if session_login_method.phone_number is not None: + phones_result = [session_login_method.phone_number] + + phones_result.extend( + [ + lm.phone_number + for lm in ordered_login_methods + if lm.phone_number is not None + and lm.phone_number not in phones_result + ] + ) + elif len(phone_numbers) == 1: + phones_result = phone_numbers + else: + phones_result = [] + if ( + session_login_method.phone_number is not None + and session_login_method.phone_number in phone_numbers + ): + phones_result = [session_login_method.phone_number] + phones_result.extend( + [ + phone + for phone in phone_numbers + if phone not in phones_result + ] + ) + + factor_id_to_phone_number_map = {} + if FactorIds.OTP_PHONE in all_factors: + factor_id_to_phone_number_map[ + FactorIds.OTP_PHONE + ] = phones_result + if FactorIds.LINK_PHONE in all_factors: + factor_id_to_phone_number_map[ + FactorIds.LINK_PHONE + ] = phones_result + + return GetPhoneNumbersForFactorsOkResult( + factor_id_to_phone_number_map=factor_id_to_phone_number_map + ) + + mfa_instance.add_func_to_get_phone_numbers_for_factors_from_other_recipes( + GetPhoneNumbersForFactorsFromOtherRecipesFunc( + get_phone_numbers_for_factors + ) + ) + + mt_recipe = MultitenancyRecipe.get_instance_optional() + if mt_recipe is not None: + for factor_id in all_factors: + mt_recipe.all_available_first_factors.append(factor_id) PostSTInitCallbacks.add_post_init_callback(callback) @@ -330,6 +578,8 @@ async def create_magic_link( tenant_id=tenant_id, user_input_code=user_input_code, user_context=user_context, + session=None, + should_try_linking_with_session_user=False, ) app_info = self.get_app_info() @@ -351,6 +601,7 @@ async def signinup( self, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], tenant_id: str, user_context: Dict[str, Any], ) -> ConsumeCodeOkResult: @@ -360,6 +611,8 @@ async def signinup( user_input_code=None, tenant_id=tenant_id, user_context=user_context, + session=session, + should_try_linking_with_session_user=False, ) consume_code_result = await self.recipe_implementation.consume_code( link_code=code_info.link_code, @@ -368,19 +621,9 @@ async def signinup( user_input_code=code_info.user_input_code, tenant_id=tenant_id, user_context=user_context, + session=session, + should_try_linking_with_session_user=False, ) if isinstance(consume_code_result, ConsumeCodeOkResult): return consume_code_result raise Exception("Failed to create user. Please retry") - - async def get_email_for_user_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, UnknownUserIdError]: - user_info = await self.recipe_implementation.get_user_by_id( - user_id, user_context - ) - if user_info is not None: - if user_info.email is not None: - return GetEmailForUserIdOkResult(user_info.email) - return EmailDoesNotExistError() - return UnknownUserIdError() diff --git a/supertokens_python/recipe/passwordless/recipe_implementation.py b/supertokens_python/recipe/passwordless/recipe_implementation.py index 19a75b97c..371c32a14 100644 --- a/supertokens_python/recipe/passwordless/recipe_implementation.py +++ b/supertokens_python/recipe/passwordless/recipe_implementation.py @@ -13,25 +13,32 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedError, + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, +) from supertokens_python.querier import Querier - -from .types import DeviceCode, DeviceType, User - from supertokens_python.normalised_url_path import NormalisedURLPath - -from .interfaces import ( +from supertokens_python.recipe.passwordless.types import DeviceCode, DeviceType +from supertokens_python.recipe.passwordless.interfaces import ( + CheckCodeExpiredUserInputCodeError, + CheckCodeIncorrectUserInputCodeError, + CheckCodeOkResult, + CheckCodeRestartFlowError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeOkResult, ConsumeCodeRestartFlowError, + ConsumedDevice, CreateCodeOkResult, CreateNewCodeForDeviceOkResult, CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, - DeleteUserInfoOkResult, - DeleteUserInfoUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, RecipeInterface, RevokeAllCodesOkResult, RevokeCodeOkResult, @@ -40,6 +47,11 @@ UpdateUserPhoneNumberAlreadyExistsError, UpdateUserUnknownUserIdError, ) +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.types import User, RecipeUserId +from supertokens_python.utils import log_debug_message +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe class RecipeImplementation(RecipeInterface): @@ -47,34 +59,166 @@ def __init__(self, querier: Querier): super().__init__() self.querier = querier + async def consume_code( + self, + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + ConsumeCodeOkResult, + ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeExpiredUserInputCodeError, + ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, + ]: + input_dict = { + "preAuthSessionId": pre_auth_session_id, + } + if link_code is not None: + input_dict["linkCode"] = link_code + else: + if user_input_code is None or device_id is None: + return ConsumeCodeRestartFlowError() + input_dict["userInputCode"] = user_input_code + input_dict["deviceId"] = device_id + + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/consume"), + input_dict, + user_context=user_context, + ) + + if response["status"] == "INCORRECT_USER_INPUT_CODE_ERROR": + return ConsumeCodeIncorrectUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "EXPIRED_USER_INPUT_CODE_ERROR": + return ConsumeCodeExpiredUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "RESTART_FLOW_ERROR": + return ConsumeCodeRestartFlowError() + + # status == "OK" + + log_debug_message("Passwordless.consumeCode code consumed OK") + + recipe_user_id = RecipeUserId(response["recipeUserId"]) + + updated_user = User.from_json(response["user"]) + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=updated_user, + recipe_user_id=recipe_user_id, + session=session, + user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if isinstance(link_result, LinkingToSessionUserFailedError): + return link_result + + updated_user = link_result.user + + response["user"] = updated_user + + return ConsumeCodeOkResult( + user=updated_user, + recipe_user_id=recipe_user_id, + consumed_device=ConsumedDevice.from_json(response["consumedDevice"]), + created_new_recipe_user=response["createdNewUser"], + ) + + async def check_code( + self, + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, + ]: + input_dict = { + "preAuthSessionId": pre_auth_session_id, + } + if link_code is not None: + input_dict["linkCode"] = link_code + else: + if user_input_code is None or device_id is None: + return CheckCodeRestartFlowError() + input_dict["userInputCode"] = user_input_code + input_dict["deviceId"] = device_id + + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/check"), + input_dict, + user_context=user_context, + ) + + if response["status"] == "INCORRECT_USER_INPUT_CODE_ERROR": + return CheckCodeIncorrectUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "EXPIRED_USER_INPUT_CODE_ERROR": + return CheckCodeExpiredUserInputCodeError( + failed_code_input_attempt_count=response["failedCodeInputAttemptCount"], + maximum_code_input_attempts=response["maximumCodeInputAttempts"], + ) + elif response["status"] == "RESTART_FLOW_ERROR": + return CheckCodeRestartFlowError() + + # status == "OK" + log_debug_message("Passwordless.checkCode code verified") + + return CheckCodeOkResult( + consumed_device=ConsumedDevice.from_json(response["consumedDevice"]) + ) + async def create_code( self, email: Union[None, str], phone_number: Union[None, str], user_input_code: Union[None, str], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], ) -> CreateCodeOkResult: - data: Dict[str, Any] = {} - if user_input_code is not None: - data = {**data, "userInputCode": user_input_code} - if email is not None: - data = {**data, "email": email} - if phone_number is not None: - data = {**data, "phoneNumber": phone_number} - result = await self.querier.send_post_request( + input_dict = {} + if email: + input_dict["email"] = email + if phone_number: + input_dict["phoneNumber"] = phone_number + if user_input_code: + input_dict["userInputCode"] = user_input_code + + response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/signinup/code"), - data, + input_dict, user_context=user_context, ) return CreateCodeOkResult( - pre_auth_session_id=result["preAuthSessionId"], - code_id=result["codeId"], - device_id=result["deviceId"], - user_input_code=result["userInputCode"], - link_code=result["linkCode"], - time_created=result["timeCreated"], - code_life_time=result["codeLifetime"], + pre_auth_session_id=response["preAuthSessionId"], + code_id=response["codeId"], + device_id=response["deviceId"], + user_input_code=response["userInputCode"], + link_code=response["linkCode"], + code_life_time=response["codeLifetime"], + time_created=response["timeCreated"], ) async def create_new_code_for_device( @@ -110,226 +254,43 @@ async def create_new_code_for_device( time_created=result["timeCreated"], ) - async def consume_code( - self, - pre_auth_session_id: str, - user_input_code: Union[str, None], - device_id: Union[str, None], - link_code: Union[str, None], - tenant_id: str, - user_context: Dict[str, Any], - ) -> Union[ - ConsumeCodeOkResult, - ConsumeCodeIncorrectUserInputCodeError, - ConsumeCodeExpiredUserInputCodeError, - ConsumeCodeRestartFlowError, - ]: - data = {"preAuthSessionId": pre_auth_session_id} - if device_id is not None: - data = {**data, "deviceId": device_id, "userInputCode": user_input_code} - else: - data = {**data, "linkCode": link_code} - result = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/consume"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - email = None - phone_number = None - if "email" in result["user"]: - email = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number = result["user"]["phoneNumber"] - user = User( - user_id=result["user"]["id"], - email=email, - phone_number=phone_number, - time_joined=result["user"]["timeJoined"], - tenant_ids=result["user"]["tenantIds"], - ) - return ConsumeCodeOkResult(result["createdNewUser"], user) - if result["status"] == "RESTART_FLOW_ERROR": - return ConsumeCodeRestartFlowError() - if result["status"] == "INCORRECT_USER_INPUT_CODE_ERROR": - return ConsumeCodeIncorrectUserInputCodeError( - failed_code_input_attempt_count=result["failedCodeInputAttemptCount"], - maximum_code_input_attempts=result["maximumCodeInputAttempts"], - ) - return ConsumeCodeExpiredUserInputCodeError( - failed_code_input_attempt_count=result["failedCodeInputAttemptCount"], - maximum_code_input_attempts=result["maximumCodeInputAttempts"], - ) - - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - param = {"userId": user_id} + async def list_codes_by_device_id( + self, device_id: str, tenant_id: str, user_context: Dict[str, Any] + ) -> Union[DeviceType, None]: + param = {"deviceId": device_id} result = await self.querier.send_get_request( - NormalisedURLPath("/recipe/user"), + NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes"), param, user_context=user_context, ) - if result["status"] == "OK": + if "devices" in result and len(result["devices"]) == 1: + codes: List[DeviceCode] = [] + if "code" in result["devices"][0]: + for code in result["devices"][0]: + codes.append( + DeviceCode( + code_id=code["codeId"], + time_created=code["timeCreated"], + code_life_time=code["codeLifetime"], + ) + ) email = None phone_number = None - if "email" in result["user"]: - email = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number = result["user"]["phoneNumber"] - return User( - user_id=result["user"]["id"], + if "email" in result["devices"][0]: + email = result["devices"][0]["email"] + if "phoneNumber" in result["devices"][0]: + phone_number = result["devices"][0]["phoneNumber"] + return DeviceType( + pre_auth_session_id=result["devices"][0]["preAuthSessionId"], + failed_code_input_attempt_count=result["devices"][0][ + "failedCodeInputAttemptCount" + ], + codes=codes, email=email, phone_number=phone_number, - time_joined=result["user"]["timeJoined"], - tenant_ids=result["user"]["tenantIds"], - ) - return None - - async def get_user_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - param = {"email": email} - result = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), - param, - user_context=user_context, - ) - if result["status"] == "OK": - email_resp = None - phone_number_resp = None - if "email" in result["user"]: - email_resp = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number_resp = result["user"]["phoneNumber"] - return User( - user_id=result["user"]["id"], - email=email_resp, - phone_number=phone_number_resp, - tenant_ids=result["user"]["tenantIds"], - time_joined=result["user"]["timeJoined"], - ) - return None - - async def get_user_by_phone_number( - self, phone_number: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - param = {"phoneNumber": phone_number} - result = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), - param, - user_context=user_context, - ) - if result["status"] == "OK": - email_resp = None - phone_number_resp = None - if "email" in result["user"]: - email_resp = result["user"]["email"] - if "phoneNumber" in result["user"]: - phone_number_resp = result["user"]["phoneNumber"] - return User( - user_id=result["user"]["id"], - email=email_resp, - phone_number=phone_number_resp, - time_joined=result["user"]["timeJoined"], - tenant_ids=result["user"]["tenantIds"], ) return None - async def update_user( - self, - user_id: str, - email: Union[str, None], - phone_number: Union[str, None], - user_context: Dict[str, Any], - ) -> Union[ - UpdateUserOkResult, - UpdateUserUnknownUserIdError, - UpdateUserEmailAlreadyExistsError, - UpdateUserPhoneNumberAlreadyExistsError, - ]: - data = {"userId": user_id} - if email is not None: - data = {**data, "email": email} - if phone_number is not None: - data = {**data, "phoneNumber": phone_number} - result = await self.querier.send_put_request( - NormalisedURLPath("/recipe/user"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - return UpdateUserOkResult() - if result["status"] == "UNKNOWN_USER_ID_ERROR": - return UpdateUserUnknownUserIdError() - if result["status"] == "EMAIL_ALREADY_EXISTS_ERROR": - return UpdateUserEmailAlreadyExistsError() - return UpdateUserPhoneNumberAlreadyExistsError() - - async def delete_email_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - data = {"userId": user_id, "email": None} - result = await self.querier.send_put_request( - NormalisedURLPath("/recipe/user"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - return DeleteUserInfoOkResult() - if result.get("EMAIL_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - if result.get("PHONE_NUMBER_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - return DeleteUserInfoUnknownUserIdError() - - async def delete_phone_number_for_user( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: - data = {"userId": user_id, "phoneNumber": None} - result = await self.querier.send_put_request( - NormalisedURLPath("/recipe/user"), - data, - user_context=user_context, - ) - if result["status"] == "OK": - return DeleteUserInfoOkResult() - if result.get("EMAIL_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - if result.get("PHONE_NUMBER_ALREADY_EXISTS_ERROR"): - raise Exception("Should never come here") - return DeleteUserInfoUnknownUserIdError() - - async def revoke_all_codes( - self, - email: Union[str, None], - phone_number: Union[str, None], - tenant_id: str, - user_context: Dict[str, Any], - ) -> RevokeAllCodesOkResult: - data: Dict[str, Any] = {} - if email is not None: - data = {**data, "email": email} - if phone_number is not None: - data = {**data, "email": phone_number} - await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes/remove"), - data, - user_context=user_context, - ) - return RevokeAllCodesOkResult() - - async def revoke_code( - self, code_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> RevokeCodeOkResult: - data = {"codeId": code_id} - await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/remove"), - data, - user_context=user_context, - ) - return RevokeCodeOkResult() - async def list_codes_by_email( self, email: str, tenant_id: str, user_context: Dict[str, Any] ) -> List[DeviceType]: @@ -412,10 +373,10 @@ async def list_codes_by_phone_number( ) return devices - async def list_codes_by_device_id( - self, device_id: str, tenant_id: str, user_context: Dict[str, Any] + async def list_codes_by_pre_auth_session_id( + self, pre_auth_session_id: str, tenant_id: str, user_context: Dict[str, Any] ) -> Union[DeviceType, None]: - param = {"deviceId": device_id} + param = {"preAuthSessionId": pre_auth_session_id} result = await self.querier.send_get_request( NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes"), param, @@ -449,39 +410,145 @@ async def list_codes_by_device_id( ) return None - async def list_codes_by_pre_auth_session_id( - self, pre_auth_session_id: str, tenant_id: str, user_context: Dict[str, Any] - ) -> Union[DeviceType, None]: - param = {"preAuthSessionId": pre_auth_session_id} - result = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes"), - param, + async def revoke_all_codes( + self, + email: Union[str, None], + phone_number: Union[str, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> RevokeAllCodesOkResult: + data: Dict[str, Any] = {} + if email is not None: + data = {**data, "email": email} + if phone_number is not None: + data = {**data, "email": phone_number} + await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/codes/remove"), + data, user_context=user_context, ) - if "devices" in result and len(result["devices"]) == 1: - codes: List[DeviceCode] = [] - if "code" in result["devices"][0]: - for code in result["devices"][0]: - codes.append( - DeviceCode( - code_id=code["codeId"], - time_created=code["timeCreated"], - code_life_time=code["codeLifetime"], - ) + return RevokeAllCodesOkResult() + + async def revoke_code( + self, code_id: str, tenant_id: str, user_context: Dict[str, Any] + ) -> RevokeCodeOkResult: + data = {"codeId": code_id} + await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/signinup/code/remove"), + data, + user_context=user_context, + ) + return RevokeCodeOkResult() + + async def delete_email_for_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + data = {"recipeUserId": recipe_user_id.get_as_string(), "email": None} + result = await self.querier.send_put_request( + NormalisedURLPath("/recipe/user"), + data, + user_context=user_context, + ) + if result["status"] == "OK": + return UpdateUserOkResult() + return UpdateUserUnknownUserIdError() + + async def delete_phone_number_for_user( + self, recipe_user_id: RecipeUserId, user_context: Dict[str, Any] + ) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + data = {"recipeUserId": recipe_user_id.get_as_string(), "phoneNumber": None} + result = await self.querier.send_put_request( + NormalisedURLPath("/recipe/user"), + data, + user_context=user_context, + ) + if result["status"] == "OK": + return UpdateUserOkResult() + return UpdateUserUnknownUserIdError() + + async def update_user( + self, + recipe_user_id: RecipeUserId, + email: Union[str, None], + phone_number: Union[str, None], + user_context: Dict[str, Any], + ) -> Union[ + UpdateUserOkResult, + UpdateUserUnknownUserIdError, + UpdateUserEmailAlreadyExistsError, + UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, + ]: + account_linking = AccountLinkingRecipe.get_instance() + if email: + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UpdateUserUnknownUserIdError() + + ev_instance = EmailVerificationRecipe.get_instance_optional() + is_email_verified = False + if ev_instance: + is_email_verified = ( + await ev_instance.recipe_implementation.is_email_verified( + recipe_user_id=recipe_user_id, + email=email, + user_context=user_context, ) - email = None - phone_number = None - if "email" in result["devices"][0]: - email = result["devices"][0]["email"] - if "phoneNumber" in result["devices"][0]: - phone_number = result["devices"][0]["phoneNumber"] - return DeviceType( - pre_auth_session_id=result["devices"][0]["preAuthSessionId"], - failed_code_input_attempt_count=result["devices"][0][ - "failedCodeInputAttemptCount" - ], - codes=codes, - email=email, - phone_number=phone_number, + ) + + is_email_change_allowed = await account_linking.is_email_change_allowed( + user=user, + is_verified=is_email_verified, + new_email=email, + session=None, + user_context=user_context, ) - return None + if not is_email_change_allowed.allowed: + return EmailChangeNotAllowedError( + reason=( + "New email cannot be applied to existing account because of account takeover risks." + if is_email_change_allowed.reason == "ACCOUNT_TAKEOVER_RISK" + else "New email cannot be applied to existing account because there is another primary user with the same email address." + ), + ) + + input_dict = { + "recipeUserId": recipe_user_id.get_as_string(), + } + if email: + input_dict = {**input_dict, "email": email} + if phone_number: + input_dict = {**input_dict, "phoneNumber": phone_number} + + response = await self.querier.send_put_request( + NormalisedURLPath("/recipe/user"), + input_dict, + user_context=user_context, + ) + if response["status"] == "UNKNOWN_USER_ID_ERROR": + return UpdateUserUnknownUserIdError() + elif response["status"] == "EMAIL_ALREADY_EXISTS_ERROR": + return UpdateUserEmailAlreadyExistsError() + elif response["status"] == "PHONE_NUMBER_ALREADY_EXISTS_ERROR": + return UpdateUserPhoneNumberAlreadyExistsError() + elif response["status"] == "EMAIL_CHANGE_NOT_ALLOWED_ERROR": + return EmailChangeNotAllowedError( + reason=response["reason"], + ) + elif response["status"] == "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR": + return PhoneNumberChangeNotAllowedError( + reason=response["reason"], + ) + + # status is OK + user = await get_user(recipe_user_id.get_as_string(), user_context) + if user is None: + return UpdateUserUnknownUserIdError() + + await account_linking.verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + return UpdateUserOkResult() diff --git a/supertokens_python/recipe/passwordless/syncio/__init__.py b/supertokens_python/recipe/passwordless/syncio/__init__.py index 7210f8153..e5a0105a4 100644 --- a/supertokens_python/recipe/passwordless/syncio/__init__.py +++ b/supertokens_python/recipe/passwordless/syncio/__init__.py @@ -11,10 +11,11 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union - +from typing import Any, Dict, List, Optional, Union from supertokens_python.async_to_sync_wrapper import sync +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.recipe.passwordless import asyncio +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.passwordless.interfaces import ( ConsumeCodeExpiredUserInputCodeError, ConsumeCodeIncorrectUserInputCodeError, @@ -24,21 +25,25 @@ CreateNewCodeForDeviceOkResult, CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, - DeleteUserInfoOkResult, - DeleteUserInfoUnknownUserIdError, RevokeAllCodesOkResult, RevokeCodeOkResult, UpdateUserEmailAlreadyExistsError, UpdateUserOkResult, UpdateUserPhoneNumberAlreadyExistsError, UpdateUserUnknownUserIdError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, + CheckCodeExpiredUserInputCodeError, + CheckCodeIncorrectUserInputCodeError, + CheckCodeOkResult, + CheckCodeRestartFlowError, ) from supertokens_python.recipe.passwordless.types import ( DeviceType, - PasswordlessLoginEmailTemplateVars, - PasswordlessLoginSMSTemplateVars, - User, + EmailTemplateVars, + SMSTemplateVars, ) +from supertokens_python.types import RecipeUserId def create_code( @@ -46,14 +51,18 @@ def create_code( email: Union[None, str] = None, phone_number: Union[None, str] = None, user_input_code: Union[None, str] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> CreateCodeOkResult: + if user_context is None: + user_context = {} return sync( asyncio.create_code( tenant_id, email=email, phone_number=phone_number, user_input_code=user_input_code, + session=session, user_context=user_context, ) ) @@ -69,6 +78,8 @@ def create_new_code_for_device( CreateNewCodeForDeviceRestartFlowError, CreateNewCodeForDeviceUserInputCodeAlreadyUsedError, ]: + if user_context is None: + user_context = {} return sync( asyncio.create_new_code_for_device( tenant_id, @@ -85,13 +96,17 @@ def consume_code( user_input_code: Union[str, None] = None, device_id: Union[str, None] = None, link_code: Union[str, None] = None, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[ ConsumeCodeOkResult, ConsumeCodeIncorrectUserInputCodeError, ConsumeCodeExpiredUserInputCodeError, ConsumeCodeRestartFlowError, + LinkingToSessionUserFailedError, ]: + if user_context is None: + user_context = {} return sync( asyncio.consume_code( tenant_id, @@ -99,39 +114,14 @@ def consume_code( user_input_code=user_input_code, device_id=device_id, link_code=link_code, + session=session, user_context=user_context, ) ) -def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - return sync(asyncio.get_user_by_id(user_id=user_id, user_context=user_context)) - - -def get_user_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - return sync( - asyncio.get_user_by_email(tenant_id, email=email, user_context=user_context) - ) - - -def get_user_by_phone_number( - tenant_id: str, - phone_number: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> Union[User, None]: - return sync( - asyncio.get_user_by_phone_number( - tenant_id=tenant_id, phone_number=phone_number, user_context=user_context - ) - ) - - def update_user( - user_id: str, + recipe_user_id: RecipeUserId, email: Union[str, None] = None, phone_number: Union[str, None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -140,10 +130,14 @@ def update_user( UpdateUserUnknownUserIdError, UpdateUserEmailAlreadyExistsError, UpdateUserPhoneNumberAlreadyExistsError, + EmailChangeNotAllowedError, + PhoneNumberChangeNotAllowedError, ]: + if user_context is None: + user_context = {} return sync( asyncio.update_user( - user_id=user_id, + recipe_user_id=recipe_user_id, email=email, phone_number=phone_number, user_context=user_context, @@ -152,18 +146,30 @@ def update_user( def delete_email_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + if user_context is None: + user_context = {} return sync( - asyncio.delete_email_for_user(user_id=user_id, user_context=user_context) + asyncio.delete_email_for_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) ) def delete_phone_number_for_user( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[DeleteUserInfoOkResult, DeleteUserInfoUnknownUserIdError]: + recipe_user_id: RecipeUserId, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[UpdateUserOkResult, UpdateUserUnknownUserIdError]: + if user_context is None: + user_context = {} return sync( - asyncio.delete_phone_number_for_user(user_id=user_id, user_context=user_context) + asyncio.delete_phone_number_for_user( + recipe_user_id=recipe_user_id, + user_context=user_context, + ) ) @@ -173,6 +179,8 @@ def revoke_all_codes( phone_number: Union[str, None] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> RevokeAllCodesOkResult: + if user_context is None: + user_context = {} return sync( asyncio.revoke_all_codes( tenant_id, email=email, phone_number=phone_number, user_context=user_context @@ -183,6 +191,8 @@ def revoke_all_codes( def revoke_code( tenant_id: str, code_id: str, user_context: Union[None, Dict[str, Any]] = None ) -> RevokeCodeOkResult: + if user_context is None: + user_context = {} return sync( asyncio.revoke_code(tenant_id, code_id=code_id, user_context=user_context) ) @@ -191,6 +201,8 @@ def revoke_code( def list_codes_by_email( tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None ) -> List[DeviceType]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_email(tenant_id, email=email, user_context=user_context) ) @@ -199,6 +211,8 @@ def list_codes_by_email( def list_codes_by_phone_number( tenant_id: str, phone_number: str, user_context: Union[None, Dict[str, Any]] = None ) -> List[DeviceType]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_phone_number( tenant_id, phone_number=phone_number, user_context=user_context @@ -211,6 +225,8 @@ def list_codes_by_device_id( device_id: str, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[DeviceType, None]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_device_id( tenant_id=tenant_id, device_id=device_id, user_context=user_context @@ -223,6 +239,8 @@ def list_codes_by_pre_auth_session_id( pre_auth_session_id: str, user_context: Union[None, Dict[str, Any]] = None, ) -> Union[DeviceType, None]: + if user_context is None: + user_context = {} return sync( asyncio.list_codes_by_pre_auth_session_id( tenant_id=tenant_id, @@ -238,6 +256,8 @@ def create_magic_link( phone_number: Union[str, None], user_context: Union[None, Dict[str, Any]] = None, ) -> str: + if user_context is None: + user_context = {} return sync( asyncio.create_magic_link( tenant_id=tenant_id, @@ -252,27 +272,62 @@ def signinup( tenant_id: str, email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> ConsumeCodeOkResult: + if user_context is None: + user_context = {} return sync( asyncio.signinup( tenant_id=tenant_id, email=email, phone_number=phone_number, user_context=user_context, + session=session, ) ) def send_email( - input_: PasswordlessLoginEmailTemplateVars, + input_: EmailTemplateVars, user_context: Union[None, Dict[str, Any]] = None, -) -> None: +): + if user_context is None: + user_context = {} return sync(asyncio.send_email(input_, user_context)) def send_sms( - input_: PasswordlessLoginSMSTemplateVars, + input_: SMSTemplateVars, user_context: Union[None, Dict[str, Any]] = None, -) -> None: +): + if user_context is None: + user_context = {} return sync(asyncio.send_sms(input_, user_context)) + + +def check_code( + tenant_id: str, + pre_auth_session_id: str, + user_input_code: Union[str, None] = None, + device_id: Union[str, None] = None, + link_code: Union[str, None] = None, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[ + CheckCodeOkResult, + CheckCodeIncorrectUserInputCodeError, + CheckCodeExpiredUserInputCodeError, + CheckCodeRestartFlowError, +]: + if user_context is None: + user_context = {} + return sync( + asyncio.check_code( + tenant_id, + pre_auth_session_id=pre_auth_session_id, + user_input_code=user_input_code, + device_id=device_id, + link_code=link_code, + user_context=user_context, + ) + ) diff --git a/supertokens_python/recipe/passwordless/types.py b/supertokens_python/recipe/passwordless/types.py index 38f7bcb82..03bc09cd6 100644 --- a/supertokens_python/recipe/passwordless/types.py +++ b/supertokens_python/recipe/passwordless/types.py @@ -25,32 +25,6 @@ ) -class User: - def __init__( - self, - user_id: str, - email: Union[str, None], - phone_number: Union[str, None], - time_joined: int, - tenant_ids: List[str], - ): - self.user_id = user_id - self.email = email - self.phone_number = phone_number - self.time_joined = time_joined - self.tenant_ids = tenant_ids - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.user_id == other.user_id - and self.email == other.email - and self.phone_number == other.phone_number - and self.time_joined == other.time_joined - and self.tenant_ids == other.tenant_ids - ) - - class DeviceCode: def __init__(self, code_id: str, time_created: str, code_life_time: int): self.code_id = code_id @@ -81,6 +55,7 @@ def __init__( code_life_time: int, pre_auth_session_id: str, email: str, + is_first_factor: bool, user_input_code: Union[str, None] = None, url_with_link_code: Union[str, None] = None, ): @@ -90,6 +65,7 @@ def __init__( self.user_input_code = user_input_code self.url_with_link_code = url_with_link_code self.tenant_id = tenant_id + self.is_first_factor = is_first_factor PasswordlessLoginEmailTemplateVars = CreateAndSendCustomEmailParameters @@ -102,6 +78,7 @@ def __init__( code_life_time: int, pre_auth_session_id: str, phone_number: str, + is_first_factor: bool, user_input_code: Union[str, None] = None, url_with_link_code: Union[str, None] = None, ): @@ -111,6 +88,7 @@ def __init__( self.user_input_code = user_input_code self.url_with_link_code = url_with_link_code self.tenant_id = tenant_id + self.is_first_factor = is_first_factor PasswordlessLoginSMSTemplateVars = CreateAndSendCustomTextMessageParameters diff --git a/supertokens_python/recipe/passwordless/utils.py b/supertokens_python/recipe/passwordless/utils.py index 9fb619c53..dcd85b39c 100644 --- a/supertokens_python/recipe/passwordless/utils.py +++ b/supertokens_python/recipe/passwordless/utils.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Union, List from supertokens_python.ingredients.emaildelivery.types import ( EmailDeliveryConfig, @@ -25,6 +25,7 @@ SMSDeliveryConfig, SMSDeliveryConfigWithService, ) +from supertokens_python.recipe.multifactorauth.types import FactorIds from supertokens_python.recipe.passwordless.types import ( PasswordlessLoginSMSTemplateVars, ) @@ -187,9 +188,9 @@ def validate_and_normalise_user_input( if override is None: override = OverrideConfig() - def get_email_delivery_config() -> EmailDeliveryConfigWithService[ - PasswordlessLoginEmailTemplateVars - ]: + def get_email_delivery_config() -> ( + EmailDeliveryConfigWithService[PasswordlessLoginEmailTemplateVars] + ): email_service = email_delivery.service if email_delivery is not None else None if email_service is None: @@ -202,9 +203,9 @@ def get_email_delivery_config() -> EmailDeliveryConfigWithService[ return EmailDeliveryConfigWithService(email_service, override=override) - def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ - PasswordlessLoginSMSTemplateVars - ]: + def get_sms_delivery_config() -> ( + SMSDeliveryConfigWithService[PasswordlessLoginSMSTemplateVars] + ): sms_service = sms_delivery.service if sms_delivery is not None else None if sms_service is None: @@ -240,3 +241,38 @@ def get_sms_delivery_config() -> SMSDeliveryConfigWithService[ get_sms_delivery_config=get_sms_delivery_config, get_custom_user_input_code=get_custom_user_input_code, ) + + +def get_enabled_pwless_factors( + config: PasswordlessConfig, +) -> List[str]: + all_factors: List[str] = [] + + if config.flow_type == "MAGIC_LINK": + if config.contact_config.contact_method == "EMAIL": + all_factors = [FactorIds.LINK_EMAIL] + elif config.contact_config.contact_method == "PHONE": + all_factors = [FactorIds.LINK_PHONE] + else: + all_factors = [FactorIds.LINK_EMAIL, FactorIds.LINK_PHONE] + elif config.flow_type == "USER_INPUT_CODE": + if config.contact_config.contact_method == "EMAIL": + all_factors = [FactorIds.OTP_EMAIL] + elif config.contact_config.contact_method == "PHONE": + all_factors = [FactorIds.OTP_PHONE] + else: + all_factors = [FactorIds.OTP_EMAIL, FactorIds.OTP_PHONE] + else: + if config.contact_config.contact_method == "EMAIL": + all_factors = [FactorIds.OTP_EMAIL, FactorIds.LINK_EMAIL] + elif config.contact_config.contact_method == "PHONE": + all_factors = [FactorIds.OTP_PHONE, FactorIds.LINK_PHONE] + else: + all_factors = [ + FactorIds.OTP_EMAIL, + FactorIds.OTP_PHONE, + FactorIds.LINK_EMAIL, + FactorIds.LINK_PHONE, + ] + + return all_factors diff --git a/supertokens_python/recipe/session/access_token.py b/supertokens_python/recipe/session/access_token.py index ac9ea51e7..1fa03950d 100644 --- a/supertokens_python/recipe/session/access_token.py +++ b/supertokens_python/recipe/session/access_token.py @@ -101,6 +101,7 @@ def get_info_from_access_token( user_data = payload session_handle = sanitize_string(payload.get("sessionHandle")) + recipe_user_id = sanitize_string(payload.get("rsub", user_id)) refresh_token_hash_1 = sanitize_string(payload.get("refreshTokenHash1")) parent_refresh_token_hash_1 = sanitize_string( payload.get("parentRefreshTokenHash1") @@ -129,6 +130,7 @@ def get_info_from_access_token( "expiryTime": expiry_time, "timeCreated": time_created, "tenantId": tenant_id, + "recipeUserId": recipe_user_id, } except Exception as e: log_debug_message( @@ -139,7 +141,40 @@ def get_info_from_access_token( def validate_access_token_structure(payload: Dict[str, Any], version: int) -> None: - if version >= 3: + if version >= 5: + if ( + not isinstance(payload.get("sub"), str) + or not isinstance(payload.get("exp"), (int, float)) + or not isinstance(payload.get("iat"), (int, float)) + or not isinstance(payload.get("sessionHandle"), str) + or not isinstance(payload.get("refreshTokenHash1"), str) + or not isinstance(payload.get("rsub"), str) + ): + log_debug_message( + "validateAccessTokenStructure: Access token is using version >= 5" + ) + # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. + raise Exception( + "Access token does not contain all the information. Maybe the structure has changed?" + ) + elif version >= 4: + if ( + not isinstance(payload.get("sub"), str) + or not isinstance(payload.get("exp"), (int, float)) + or not isinstance(payload.get("iat"), (int, float)) + or not isinstance(payload.get("sessionHandle"), str) + or not isinstance(payload.get("refreshTokenHash1"), str) + ): + log_debug_message( + "validateAccessTokenStructure: Access token is using version >= 4" + ) + # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. + raise Exception( + "Access token does not contain all the information. Maybe the structure has changed?" + ) + elif version >= 3: if ( not isinstance(payload.get("sub"), str) or not isinstance(payload.get("exp"), (int, float)) @@ -151,6 +186,7 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No "validateAccessTokenStructure: Access token is using version >= 3" ) # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. raise Exception( "Access token does not contain all the information. Maybe the structure has changed?" ) @@ -160,18 +196,19 @@ def validate_access_token_structure(payload: Dict[str, Any], version: int) -> No raise Exception( "Access token does not contain all the information. Maybe the structure has changed?" ) - elif ( not isinstance(payload.get("sessionHandle"), str) - or payload.get("userData") is None + or not isinstance(payload.get("userId"), str) or not isinstance(payload.get("refreshTokenHash1"), str) - or not isinstance(payload.get("expiryTime"), (float, int)) - or not isinstance(payload.get("timeCreated"), (float, int)) + or payload.get("userData") is None + or not isinstance(payload.get("expiryTime"), (int, float)) + or not isinstance(payload.get("timeCreated"), (int, float)) ): log_debug_message( "validateAccessTokenStructure: Access token is using version < 3" ) # The error message below will be logged by the error handler that translates this into a TRY_REFRESH_TOKEN_ERROR + # it would come here if we change the structure of the JWT. raise Exception( "Access token does not contain all the information. Maybe the structure has changed?" ) diff --git a/supertokens_python/recipe/session/api/implementation.py b/supertokens_python/recipe/session/api/implementation.py index 1521243e8..2229f6ead 100644 --- a/supertokens_python/recipe/session/api/implementation.py +++ b/supertokens_python/recipe/session/api/implementation.py @@ -49,12 +49,11 @@ async def refresh_post( async def signout_post( self, - session: Optional[SessionContainer], + session: SessionContainer, api_options: APIOptions, user_context: Dict[str, Any], ) -> SignOutOkayResponse: - if session is not None: - await session.revoke_session(user_context) + await session.revoke_session(user_context) return SignOutOkayResponse() async def verify_session( diff --git a/supertokens_python/recipe/session/api/signout.py b/supertokens_python/recipe/session/api/signout.py index 5fcf69c07..9c0c68be0 100644 --- a/supertokens_python/recipe/session/api/signout.py +++ b/supertokens_python/recipe/session/api/signout.py @@ -48,6 +48,8 @@ async def handle_signout_api( user_context=user_context, ) + assert session is not None + response = await api_implementation.signout_post(session, api_options, user_context) if api_options.response is None: raise Exception("Should never come here") diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index a50422885..da67f9699 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -20,7 +20,6 @@ from supertokens_python.recipe.session.interfaces import ( ClaimsValidationResult, GetClaimValueOkResult, - JSONObject, SessionClaim, SessionClaimValidator, SessionContainer, @@ -28,7 +27,7 @@ SessionInformationResult, ) from supertokens_python.recipe.session.recipe import SessionRecipe -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import FRAMEWORKS, resolve from ...jwt.interfaces import ( @@ -45,6 +44,7 @@ from ..utils import get_required_claim_validators from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID +from supertokens_python.asyncio import get_user _T = TypeVar("_T") @@ -52,7 +52,7 @@ async def create_new_session( request: Any, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -68,12 +68,18 @@ async def create_new_session( config = recipe_instance.config app_info = recipe_instance.app_info + user = await get_user(recipe_user_id.get_as_string(), user_context) + user_id = recipe_user_id.get_as_string() + if user is not None: + user_id = user.id + return await create_new_session_in_request( request, user_context, recipe_instance, access_token_payload, user_id, + recipe_user_id, config, app_info, session_data_in_database, @@ -83,7 +89,7 @@ async def create_new_session( async def create_new_session_without_request_response( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, disable_anti_csrf: bool = False, @@ -111,12 +117,20 @@ async def create_new_session_without_request_response( if prop in final_access_token_payload: del final_access_token_payload[prop] + user = await get_user(recipe_user_id.get_as_string(), user_context) + user_id = recipe_user_id.get_as_string() + if user is not None: + user_id = user.id + for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, tenant_id, user_context) + update = await claim.build( + user_id, recipe_user_id, tenant_id, final_access_token_payload, user_context + ) final_access_token_payload = {**final_access_token_payload, **update} return await SessionRecipe.get_instance().recipe_implementation.create_new_session( user_id, + recipe_user_id, final_access_token_payload, session_data_in_database, disable_anti_csrf, @@ -157,6 +171,7 @@ async def validate_claims_for_session_handle( recipe_impl.get_global_claim_validators( session_info.tenant_id, session_info.user_id, + session_info.recipe_user_id, claim_validators_added_by_other_recipes, user_context, ) @@ -173,6 +188,7 @@ async def validate_claims_for_session_handle( claim_validation_res = await recipe_impl.validate_claims( session_info.user_id, + session_info.recipe_user_id, session_info.custom_claims_in_access_token_payload, claim_validators, user_context, @@ -190,53 +206,6 @@ async def validate_claims_for_session_handle( return ClaimsValidationResult(claim_validation_res.invalid_claims) -async def validate_claims_in_jwt_payload( - tenant_id: str, - user_id: str, - jwt_payload: JSONObject, - override_global_claim_validators: Optional[ - Callable[ - [ - List[SessionClaimValidator], - str, - Dict[str, Any], - ], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ] = None, - user_context: Union[None, Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - - recipe_impl = SessionRecipe.get_instance().recipe_implementation - - claim_validators_added_by_other_recipes = ( - SessionRecipe.get_instance().get_claim_validators_added_by_other_recipes() - ) - global_claim_validators = await resolve( - recipe_impl.get_global_claim_validators( - tenant_id, - user_id, - claim_validators_added_by_other_recipes, - user_context, - ) - ) - - if override_global_claim_validators is not None: - claim_validators = await resolve( - override_global_claim_validators( - global_claim_validators, user_id, user_context - ) - ) - else: - claim_validators = global_claim_validators - - return await recipe_impl.validate_claims_in_jwt_payload( - user_id, jwt_payload, claim_validators, user_context - ) - - async def fetch_and_set_claim( session_handle: str, claim: SessionClaim[Any], @@ -435,25 +404,35 @@ async def revoke_session( async def revoke_all_sessions_for_user( user_id: str, + revoke_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: if user_context is None: user_context = {} return await SessionRecipe.get_instance().recipe_implementation.revoke_all_sessions_for_user( - user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context + user_id, + revoke_sessions_for_linked_accounts, + tenant_id or DEFAULT_TENANT_ID, + tenant_id is None, + user_context, ) async def get_all_session_handles_for_user( user_id: str, + fetch_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: if user_context is None: user_context = {} return await SessionRecipe.get_instance().recipe_implementation.get_all_session_handles_for_user( - user_id, tenant_id or DEFAULT_TENANT_ID, tenant_id is None, user_context + user_id, + fetch_sessions_for_linked_accounts, + tenant_id or DEFAULT_TENANT_ID, + tenant_id is None, + user_context, ) diff --git a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py index aec6c6f71..37af98f51 100644 --- a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py @@ -13,7 +13,7 @@ # under the License. from typing import Any, Callable, Dict, Optional -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from .primitive_claim import PrimitiveClaim, PrimitiveClaimValidators @@ -31,7 +31,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[bool]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py index f34afc750..5c13a02b6 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, Optional, TypeVar, Union, Generic, List -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import get_timestamp_ms from ..interfaces import ( @@ -267,7 +267,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[PrimitiveList]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py index b3a8ad08d..0ce0a2ceb 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import get_timestamp_ms from ..interfaces import ( @@ -132,7 +132,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[Primitive]], ], default_max_age_in_sec: Optional[int] = None, diff --git a/supertokens_python/recipe/session/exceptions.py b/supertokens_python/recipe/session/exceptions.py index 9c0d1552e..d9eedddc9 100644 --- a/supertokens_python/recipe/session/exceptions.py +++ b/supertokens_python/recipe/session/exceptions.py @@ -16,13 +16,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Union from supertokens_python.exceptions import SuperTokensError +from supertokens_python.types import RecipeUserId if TYPE_CHECKING: from .interfaces import ResponseMutator -def raise_token_theft_exception(user_id: str, session_handle: str) -> NoReturn: - raise TokenTheftError(user_id, session_handle) +def raise_token_theft_exception( + user_id: str, recipe_user_id: RecipeUserId, session_handle: str +) -> NoReturn: + raise TokenTheftError(user_id, recipe_user_id, session_handle) def raise_try_refresh_token_exception(ex: Union[str, Exception]) -> NoReturn: @@ -60,9 +63,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: class TokenTheftError(SuperTokensSessionError): - def __init__(self, user_id: str, session_handle: str): + def __init__(self, user_id: str, recipe_user_id: RecipeUserId, session_handle: str): super().__init__("token theft detected") self.user_id = user_id + self.recipe_user_id = recipe_user_id self.session_handle = session_handle @@ -83,12 +87,15 @@ def __init__(self, msg: str, payload: List[ClaimValidationError]): class ClaimValidationError: - def __init__(self, id_: str, reason: Optional[Dict[str, Any]]): - self.id = id_ - self.reason = reason + id_: str + reason: Optional[Union[str, Dict[str, Any]]] + + def __init__(self, id_: str, reason: Optional[Union[str, Dict[str, Any]]]): + self.id_: str = id_ + self.reason: Optional[Union[str, Dict[str, Any]]] = reason def to_json(self): - result: Dict[str, Any] = {"id": self.id} + result: Dict[str, Any] = {"id": self.id_} if self.reason is not None: result["reason"] = self.reason diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 31802fad5..0ffe24518 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -28,7 +28,12 @@ from typing_extensions import TypedDict from supertokens_python.async_to_sync_wrapper import sync -from supertokens_python.types import APIResponse, GeneralErrorResponse, MaybeAwaitable +from supertokens_python.types import ( + APIResponse, + GeneralErrorResponse, + MaybeAwaitable, + RecipeUserId, +) from ...utils import resolve from .exceptions import ClaimValidationError @@ -45,6 +50,7 @@ def __init__( self, handle: str, user_id: str, + recipe_user_id: RecipeUserId, user_data_in_jwt: Dict[str, Any], tenant_id: str, ): @@ -52,6 +58,16 @@ def __init__( self.user_id = user_id self.user_data_in_jwt = user_data_in_jwt self.tenant_id = tenant_id + self.recipe_user_id = recipe_user_id + + def to_json(self) -> Dict[str, Any]: + return { + "handle": self.handle, + "userId": self.user_id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "tenantId": self.tenant_id, + "userDataInJWT": self.user_data_in_jwt, + } class AccessTokenObj: @@ -60,18 +76,34 @@ def __init__(self, token: str, expiry: int, created_time: int): self.expiry = expiry self.created_time = created_time + def to_json(self) -> Dict[str, Any]: + return { + "token": self.token, + "expiry": self.expiry, + "createdTime": self.created_time, + } + class RegenerateAccessTokenOkResult: def __init__(self, session: SessionObj, access_token: Union[AccessTokenObj, None]): self.session = session self.access_token = access_token + def to_json(self) -> Dict[str, Any]: + return { + "session": self.session.to_json(), + "accessToken": ( + self.access_token.to_json() if self.access_token is not None else None + ), + } + class SessionInformationResult: def __init__( self, session_handle: str, user_id: str, + recipe_user_id: RecipeUserId, session_data_in_database: Dict[str, Any], expiry: int, custom_claims_in_access_token_payload: Dict[str, Any], @@ -87,6 +119,19 @@ def __init__( ) self.time_created = time_created self.tenant_id = tenant_id + self.recipe_user_id = recipe_user_id + + def to_json(self) -> Dict[str, Any]: + return { + "sessionHandle": self.session_handle, + "userId": self.user_id, + "recipeUserId": self.recipe_user_id.get_as_string(), + "sessionDataInDatabase": self.session_data_in_database, + "expiry": self.expiry, + "customClaimsInAccessTokenPayload": self.custom_claims_in_access_token_payload, + "timeCreated": self.time_created, + "tenantId": self.tenant_id, + } class ReqResInfo: @@ -126,6 +171,13 @@ def __init__( self.invalid_claims = invalid_claims self.access_token_payload_update = access_token_payload_update + def to_json(self) -> Dict[str, Any]: + return { + "status": "OK", + "invalidClaims": [i.to_json() for i in self.invalid_claims], + "accessTokenPayloadUpdate": self.access_token_payload_update, + } + class GetSessionTokensDangerouslyDict(TypedDict): accessToken: str @@ -143,6 +195,7 @@ def __init__(self): async def create_new_session( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -156,6 +209,7 @@ def get_global_claim_validators( self, tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> MaybeAwaitable[List[SessionClaimValidator]]: @@ -183,22 +237,13 @@ async def get_session( async def validate_claims( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Dict[str, Any], claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> ClaimsValidationResult: pass - @abstractmethod - async def validate_claims_in_jwt_payload( - self, - user_id: str, - jwt_payload: JSONObject, - claim_validators: List[SessionClaimValidator], - user_context: Dict[str, Any], - ) -> ClaimsValidationResult: - pass - @abstractmethod async def refresh_session( self, @@ -219,6 +264,7 @@ async def revoke_session( async def revoke_all_sessions_for_user( self, user_id: str, + revoke_sessions_for_linked_accounts: bool, tenant_id: str, revoke_across_all_tenants: bool, user_context: Dict[str, Any], @@ -229,6 +275,7 @@ async def revoke_all_sessions_for_user( async def get_all_session_handles_for_user( self, user_id: str, + fetch_sessions_for_linked_accounts: bool, tenant_id: str, fetch_across_all_tenants: bool, user_context: Dict[str, Any], @@ -354,7 +401,7 @@ async def refresh_post( @abstractmethod async def signout_post( self, - session: Optional[SessionContainer], + session: SessionContainer, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[SignOutOkayResponse, GeneralErrorResponse]: @@ -387,6 +434,13 @@ def __init__(self, token: str, expiry: int, created_time: int): self.expiry = expiry self.created_time = created_time + def to_json(self) -> Dict[str, Any]: + return { + "token": self.token, + "expiry": self.expiry, + "createdTime": self.created_time, + } + class SessionContainer(ABC): # pylint: disable=too-many-public-methods def __init__( @@ -399,6 +453,7 @@ def __init__( anti_csrf_token: Optional[str], session_handle: str, user_id: str, + recipe_user_id: RecipeUserId, user_data_in_access_token: Optional[Dict[str, Any]], req_res_info: Optional[ReqResInfo], access_token_updated: bool, @@ -416,7 +471,7 @@ def __init__( self.req_res_info: Optional[ReqResInfo] = req_res_info self.access_token_updated = access_token_updated self.tenant_id = tenant_id - + self.recipe_user_id = recipe_user_id self.response_mutators: List[ResponseMutator] = [] @abstractmethod @@ -460,6 +515,12 @@ async def merge_into_access_token_payload( def get_user_id(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass + @abstractmethod + def get_recipe_user_id( + self, user_context: Optional[Dict[str, Any]] = None + ) -> RecipeUserId: + pass + @abstractmethod def get_tenant_id(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass @@ -601,11 +662,11 @@ def sync_remove_claim( def sync_attach_to_request_response( self, request: BaseRequest, - token_transfer: TokenTransferMethod, + transfer_method: TokenTransferMethod, user_context: Dict[str, Any], ) -> None: return sync( - self.attach_to_request_response(request, token_transfer, user_context) + self.attach_to_request_response(request, transfer_method, user_context) ) # This is there so that we can do session["..."] to access some of the members of this class @@ -618,7 +679,7 @@ def __init__( self, key: str, fetch_value: Callable[ - [str, str, Dict[str, Any]], + [str, RecipeUserId, str, Dict[str, Any], Dict[str, Any]], MaybeAwaitable[Optional[_T]], ], ) -> None: @@ -663,13 +724,16 @@ def get_value_from_payload( async def build( self, user_id: str, + recipe_user_id: RecipeUserId, tenant_id: str, - user_context: Optional[Dict[str, Any]] = None, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> JSONObject: - if user_context is None: - user_context = {} - - value = await resolve(self.fetch_value(user_id, tenant_id, user_context)) + value = await resolve( + self.fetch_value( + user_id, recipe_user_id, tenant_id, current_payload, user_context + ) + ) if value is None: return {} diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 6fa02621c..7e4eb4799 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -263,7 +263,7 @@ async def handle_error( response, self, request, user_context ) return await self.config.error_handlers.on_token_theft_detected( - request, err.session_handle, err.user_id, response + request, err.session_handle, err.user_id, err.recipe_user_id, response ) if isinstance(err, InvalidClaimsError): log_debug_message("errorHandler: returning INVALID_CLAIMS") diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index ec02f8bbf..363259eff 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -20,7 +20,7 @@ from supertokens_python.normalised_url_path import NormalisedURLPath from supertokens_python.utils import resolve -from ...types import MaybeAwaitable +from ...types import MaybeAwaitable, RecipeUserId from . import session_functions from .access_token import validate_access_token_structure from .cookie_and_header import build_front_token @@ -29,7 +29,6 @@ AccessTokenObj, ClaimsValidationResult, GetClaimValueOkResult, - JSONObject, RecipeInterface, RegenerateAccessTokenOkResult, SessionClaim, @@ -62,6 +61,7 @@ def __init__(self, querier: Querier, config: SessionConfig, app_info: AppInfo): async def create_new_session( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -73,7 +73,7 @@ async def create_new_session( result = await session_functions.create_new_session( self, tenant_id, - user_id, + recipe_user_id, disable_anti_csrf is True, access_token_payload, session_data_in_database, @@ -96,6 +96,7 @@ async def create_new_session( result.antiCsrfToken, result.session.handle, result.session.userId, + recipe_user_id, payload, None, True, @@ -107,6 +108,7 @@ async def create_new_session( async def validate_claims( self, user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Dict[str, Any], claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any], @@ -128,14 +130,16 @@ async def validate_claims( value = await resolve( validator.claim.fetch_value( user_id, + recipe_user_id, access_token_payload.get("tId", DEFAULT_TENANT_ID), + access_token_payload, user_context, ) ) log_debug_message( "update_claims_in_payload_if_needed %s refetch result %s", validator.id, - json.dumps(value), + value, ) if value is not None: access_token_payload = validator.claim.add_to_payload_( @@ -151,21 +155,6 @@ async def validate_claims( return ClaimsValidationResult(invalid_claims, access_token_payload_update) - async def validate_claims_in_jwt_payload( - self, - user_id: str, - jwt_payload: JSONObject, - claim_validators: List[SessionClaimValidator], - user_context: Dict[str, Any], - ) -> ClaimsValidationResult: - invalid_claims = await validate_claims_in_payload( - claim_validators, - jwt_payload, - user_context, - ) - - return ClaimsValidationResult(invalid_claims) - async def get_session( self, access_token: Optional[str], @@ -266,6 +255,7 @@ async def get_session( anti_csrf_token, response.session.handle, response.session.userId, + response.session.recipe_user_id, payload, None, access_token_updated, @@ -320,6 +310,7 @@ async def refresh_session( response.antiCsrfToken, response.session.handle, response.session.userId, + response.session.recipe_user_id, user_data_in_access_token=payload, req_res_info=None, access_token_updated=True, @@ -338,23 +329,35 @@ async def revoke_session( async def revoke_all_sessions_for_user( self, user_id: str, + revoke_sessions_for_linked_accounts: bool, tenant_id: Optional[str], revoke_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: return await session_functions.revoke_all_sessions_for_user( - self, user_id, tenant_id, revoke_across_all_tenants, user_context + self, + user_id, + revoke_sessions_for_linked_accounts, + tenant_id, + revoke_across_all_tenants, + user_context, ) async def get_all_session_handles_for_user( self, user_id: str, + fetch_sessions_for_linked_accounts: bool, tenant_id: Optional[str], fetch_across_all_tenants: bool, user_context: Dict[str, Any], ) -> List[str]: return await session_functions.get_all_session_handles_for_user( - self, user_id, tenant_id, fetch_across_all_tenants, user_context + self, + user_id, + fetch_sessions_for_linked_accounts, + tenant_id, + fetch_across_all_tenants, + user_context, ) async def revoke_multiple_sessions( @@ -419,7 +422,11 @@ async def fetch_and_set_claim( return False access_token_payload_update = await claim.build( - session_info.user_id, session_info.tenant_id, user_context + session_info.user_id, + session_info.recipe_user_id, + session_info.tenant_id, + session_info.custom_claims_in_access_token_payload, + user_context, ) return await self.merge_into_access_token_payload( session_handle, access_token_payload_update, user_context @@ -457,6 +464,7 @@ def get_global_claim_validators( self, tenant_id: str, user_id: str, + recipe_user_id: RecipeUserId, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> MaybeAwaitable[List[SessionClaimValidator]]: @@ -498,6 +506,7 @@ async def regenerate_access_token( session = SessionObj( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], response["session"]["tenantId"], ) diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index e55a6c5da..6150fb0e7 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -17,6 +17,7 @@ raise_invalid_claims_exception, raise_unauthorised_exception, ) +from supertokens_python.types import RecipeUserId from .jwt import parse_jwt_without_signature_verification from .utils import TokenTransferMethod @@ -139,6 +140,11 @@ async def update_session_data_in_database( def get_user_id(self, user_context: Union[Dict[str, Any], None] = None) -> str: return self.user_id + def get_recipe_user_id( + self, user_context: Union[Dict[str, Any], None] = None + ) -> RecipeUserId: + return self.recipe_user_id + def get_tenant_id(self, user_context: Union[Dict[str, Any], None] = None) -> str: return self.tenant_id @@ -157,9 +163,9 @@ def get_all_session_tokens_dangerously(self) -> GetSessionTokensDangerouslyDict: return { "accessToken": self.access_token, "accessAndFrontTokenUpdated": self.access_token_updated, - "refreshToken": None - if self.refresh_token is None - else self.refresh_token.token, + "refreshToken": ( + None if self.refresh_token is None else self.refresh_token.token + ), "frontToken": self.front_token, "antiCsrfToken": self.anti_csrf_token, } @@ -204,6 +210,7 @@ async def assert_claims( validate_claim_res = await self.recipe_implementation.validate_claims( self.get_user_id(user_context), + self.get_recipe_user_id(user_context), self.get_access_token_payload(user_context), claim_validators, user_context, @@ -230,7 +237,11 @@ async def fetch_and_set_claim( user_context = {} update = await claim.build( - self.get_user_id(), self.get_tenant_id(), user_context + self.get_user_id(user_context=user_context), + self.get_recipe_user_id(user_context=user_context), + self.get_tenant_id(user_context=user_context), + self.get_access_token_payload(user_context=user_context), + user_context, ) return await self.merge_into_access_token_payload(update, user_context) diff --git a/supertokens_python/recipe/session/session_functions.py b/supertokens_python/recipe/session/session_functions.py index b2e214721..a09200633 100644 --- a/supertokens_python/recipe/session/session_functions.py +++ b/supertokens_python/recipe/session/session_functions.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional from supertokens_python.recipe.session.interfaces import SessionInformationResult +from supertokens_python.types import RecipeUserId from .access_token import get_info_from_access_token from .jwt import ParsedJWTInfo @@ -26,7 +27,7 @@ from supertokens_python.logger import log_debug_message from supertokens_python.normalised_url_path import NormalisedURLPath -from supertokens_python.process_state import AllowedProcessStates, ProcessState +from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.recipe.session.interfaces import TokenInfo from .exceptions import ( @@ -40,9 +41,17 @@ class CreateOrRefreshAPIResponseSession: - def __init__(self, handle: str, userId: str, userDataInJWT: Any, tenant_id: str): + def __init__( + self, + handle: str, + userId: str, + recipe_user_id: RecipeUserId, + userDataInJWT: Any, + tenant_id: str, + ): self.handle = handle self.userId = userId + self.recipe_user_id = recipe_user_id self.userDataInJWT = userDataInJWT self.tenant_id = tenant_id @@ -66,12 +75,14 @@ def __init__( self, handle: str, userId: str, + recipe_user_id: RecipeUserId, userDataInJWT: Dict[str, Any], expiryTime: int, tenant_id: str, ) -> None: self.handle = handle self.userId = userId + self.recipe_user_id = recipe_user_id self.userDataInJWT = userDataInJWT self.expiryTime = expiryTime self.tenant_id = tenant_id @@ -97,7 +108,7 @@ def __init__( async def create_new_session( recipe_implementation: RecipeImplementation, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, disable_anti_csrf: bool, access_token_payload: Union[None, Dict[str, Any]], session_data_in_database: Union[None, Dict[str, Any]], @@ -115,7 +126,7 @@ async def create_new_session( response = await recipe_implementation.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/session"), { - "userId": user_id, + "userId": recipe_user_id.get_as_string(), "userDataInJWT": access_token_payload, "userDataInDatabase": session_data_in_database, "useDynamicSigningKey": recipe_implementation.config.use_dynamic_access_token_signing_key, @@ -128,6 +139,7 @@ async def create_new_session( CreateOrRefreshAPIResponseSession( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], response["session"]["tenantId"], ), @@ -264,15 +276,14 @@ async def get_session( GetSessionAPIResponseSession( access_token_info["sessionHandle"], access_token_info["userId"], + RecipeUserId(access_token_info["recipeUserId"]), access_token_info["userData"], access_token_info["expiryTime"], access_token_info["tenantId"], ) ) - ProcessState.get_instance().add_state( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY - ) + ProcessState.get_instance().add_state(PROCESS_STATE.CALLING_SERVICE_IN_VERIFY) data = { "accessToken": parsed_access_token.raw_token_string, @@ -293,6 +304,7 @@ async def get_session( GetSessionAPIResponseSession( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], ( response.get("accessToken", {}).get( @@ -371,6 +383,7 @@ async def refresh_session( CreateOrRefreshAPIResponseSession( response["session"]["handle"], response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), response["session"]["userDataInJWT"], response["session"]["tenantId"], ), @@ -395,13 +408,16 @@ async def refresh_session( "refreshSession: Returning TOKEN_THEFT_DETECTED because of core response" ) raise_token_theft_exception( - response["session"]["userId"], response["session"]["handle"] + response["session"]["userId"], + RecipeUserId(response["session"]["recipeUserId"]), + response["session"]["handle"], ) async def revoke_all_sessions_for_user( recipe_implementation: RecipeImplementation, user_id: str, + revoke_sessions_for_linked_accounts: bool, tenant_id: Optional[str], revoke_across_all_tenants: bool, user_context: Optional[Dict[str, Any]], @@ -412,13 +428,21 @@ async def revoke_all_sessions_for_user( if revoke_across_all_tenants: response = await recipe_implementation.querier.send_post_request( NormalisedURLPath("/recipe/session/remove"), - {"userId": user_id, "revokeAcrossAllTenants": revoke_across_all_tenants}, + { + "userId": user_id, + "revokeAcrossAllTenants": revoke_across_all_tenants, + "revokeSessionsForLinkedAccounts": revoke_sessions_for_linked_accounts, + }, user_context=user_context, ) else: response = await recipe_implementation.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/session/remove"), - {"userId": user_id, "revokeAcrossAllTenants": revoke_across_all_tenants}, + { + "userId": user_id, + "revokeAcrossAllTenants": revoke_across_all_tenants, + "revokeSessionsForLinkedAccounts": revoke_sessions_for_linked_accounts, + }, user_context=user_context, ) return response["sessionHandlesRevoked"] @@ -427,6 +451,7 @@ async def revoke_all_sessions_for_user( async def get_all_session_handles_for_user( recipe_implementation: RecipeImplementation, user_id: str, + fetch_sessions_for_linked_accounts: bool, tenant_id: Optional[str], fetch_across_all_tenants: bool, user_context: Optional[Dict[str, Any]], @@ -437,13 +462,21 @@ async def get_all_session_handles_for_user( if fetch_across_all_tenants: response = await recipe_implementation.querier.send_get_request( NormalisedURLPath("/recipe/session/user"), - {"userId": user_id, "fetchAcrossAllTenants": fetch_across_all_tenants}, + { + "userId": user_id, + "fetchAcrossAllTenants": fetch_across_all_tenants, + "fetchSessionsForAllLinkedAccounts": fetch_sessions_for_linked_accounts, + }, user_context=user_context, ) else: response = await recipe_implementation.querier.send_get_request( NormalisedURLPath(f"{tenant_id}/recipe/session/user"), - {"userId": user_id, "fetchAcrossAllTenants": fetch_across_all_tenants}, + { + "userId": user_id, + "fetchAcrossAllTenants": fetch_across_all_tenants, + "fetchSessionsForAllLinkedAccounts": fetch_sessions_for_linked_accounts, + }, user_context=user_context, ) return response["sessionHandles"] @@ -523,6 +556,7 @@ async def get_session_information( return SessionInformationResult( response["sessionHandle"], response["userId"], + RecipeUserId(response["recipeUserId"]), response["userDataInDatabase"], response["expiry"], response["userDataInJWT"], diff --git a/supertokens_python/recipe/session/session_request_functions.py b/supertokens_python/recipe/session/session_request_functions.py index 2fd8fcfa9..7d70cb2ee 100644 --- a/supertokens_python/recipe/session/session_request_functions.py +++ b/supertokens_python/recipe/session/session_request_functions.py @@ -54,7 +54,7 @@ get_required_claim_validators, get_auth_mode_from_header, ) -from supertokens_python.types import MaybeAwaitable +from supertokens_python.types import MaybeAwaitable, RecipeUserId from supertokens_python.utils import ( FRAMEWORKS, get_rid_from_header, @@ -239,6 +239,7 @@ async def create_new_session_in_request( recipe_instance: SessionRecipe, access_token_payload: Dict[str, Any], user_id: str, + recipe_user_id: RecipeUserId, config: SessionConfig, app_info: AppInfo, session_data_in_database: Dict[str, Any], @@ -268,7 +269,9 @@ async def create_new_session_in_request( del final_access_token_payload[prop] for claim in claims_added_by_other_recipes: - update = await claim.build(user_id, tenant_id, user_context) + update = await claim.build( + user_id, recipe_user_id, tenant_id, final_access_token_payload, user_context + ) final_access_token_payload.update(update) log_debug_message("createNewSession: Access token payload built") @@ -314,6 +317,7 @@ async def create_new_session_in_request( disable_anti_csrf = output_transfer_method == "header" session = await recipe_instance.recipe_implementation.create_new_session( user_id, + recipe_user_id, final_access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/supertokens_python/recipe/session/syncio/__init__.py b/supertokens_python/recipe/session/syncio/__init__.py index 71c87545e..50eec7a11 100644 --- a/supertokens_python/recipe/session/syncio/__init__.py +++ b/supertokens_python/recipe/session/syncio/__init__.py @@ -30,17 +30,17 @@ SessionInformationResult, SessionClaimValidator, SessionClaim, - JSONObject, ClaimsValidationResult, SessionDoesNotExistError, GetClaimValueOkResult, ) +from supertokens_python.recipe.session.recipe_implementation import RecipeUserId def create_new_session( request: Any, tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, user_context: Union[None, Dict[str, Any]] = None, @@ -51,9 +51,9 @@ def create_new_session( return sync( async_create_new_session( - tenant_id=tenant_id, request=request, - user_id=user_id, + tenant_id=tenant_id, + recipe_user_id=recipe_user_id, access_token_payload=access_token_payload, session_data_in_database=session_data_in_database, user_context=user_context, @@ -63,7 +63,7 @@ def create_new_session( def create_new_session_without_request_response( tenant_id: str, - user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None] = None, session_data_in_database: Union[Dict[str, Any], None] = None, disable_anti_csrf: bool = False, @@ -76,7 +76,7 @@ def create_new_session_without_request_response( return sync( async_create_new_session_without_request_response( tenant_id, - user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -187,6 +187,7 @@ def revoke_session( def revoke_all_sessions_for_user( user_id: str, + revoke_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: @@ -194,11 +195,19 @@ def revoke_all_sessions_for_user( revoke_all_sessions_for_user as async_revoke_all_sessions_for_user, ) - return sync(async_revoke_all_sessions_for_user(user_id, tenant_id, user_context)) + return sync( + async_revoke_all_sessions_for_user( + user_id, + revoke_sessions_for_linked_accounts, + tenant_id, + user_context, + ) + ) def get_all_session_handles_for_user( user_id: str, + fetch_sessions_for_linked_accounts: bool = True, tenant_id: Optional[str] = None, user_context: Union[None, Dict[str, Any]] = None, ) -> List[str]: @@ -207,7 +216,12 @@ def get_all_session_handles_for_user( ) return sync( - async_get_all_session_handles_for_user(user_id, tenant_id, user_context) + async_get_all_session_handles_for_user( + user_id, + fetch_sessions_for_linked_accounts, + tenant_id, + user_context, + ) ) @@ -365,30 +379,3 @@ def validate_claims_for_session_handle( session_handle, override_global_claim_validators, user_context ) ) - - -def validate_claims_in_jwt_payload( - tenant_id: str, - user_id: str, - jwt_payload: JSONObject, - override_global_claim_validators: Optional[ - Callable[ - [List[SessionClaimValidator], str, Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ] = None, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.session.asyncio import ( - validate_claims_in_jwt_payload as async_validate_claims_in_jwt_payload, - ) - - return sync( - async_validate_claims_in_jwt_payload( - tenant_id, - user_id, - jwt_payload, - override_global_claim_validators, - user_context, - ) - ) diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 5577147fa..96e5c43a4 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -30,7 +30,7 @@ send_non_200_response_with_message, ) -from ...types import MaybeAwaitable +from ...types import MaybeAwaitable, RecipeUserId from .constants import AUTH_MODE_HEADER_KEY, SESSION_REFRESH from .exceptions import ClaimValidationError @@ -100,7 +100,7 @@ class ErrorHandlers: def __init__( self, on_token_theft_detected: Callable[ - [BaseRequest, str, str, BaseResponse], + [BaseRequest, str, str, RecipeUserId, BaseResponse], Union[BaseResponse, Awaitable[BaseResponse]], ], on_try_refresh_token: Callable[ @@ -131,10 +131,13 @@ async def on_token_theft_detected( request: BaseRequest, session_handle: str, user_id: str, + recipe_user_id: RecipeUserId, response: BaseResponse, ) -> BaseResponse: return await resolve( - self.__on_token_theft_detected(request, session_handle, user_id, response) + self.__on_token_theft_detected( + request, session_handle, user_id, recipe_user_id, response + ) ) async def on_try_refresh_token( @@ -182,7 +185,7 @@ def __init__( on_token_theft_detected: Union[ None, Callable[ - [BaseRequest, str, str, BaseResponse], + [BaseRequest, str, str, RecipeUserId, BaseResponse], Union[BaseResponse, Awaitable[BaseResponse]], ], ] = None, @@ -261,7 +264,11 @@ async def default_try_refresh_token_callback( async def default_token_theft_detected_callback( - _: BaseRequest, session_handle: str, __: str, response: BaseResponse + _: BaseRequest, + session_handle: str, + __: str, + ___: RecipeUserId, + response: BaseResponse, ) -> BaseResponse: from .recipe import SessionRecipe @@ -576,6 +583,7 @@ async def get_required_claim_validators( SessionRecipe.get_instance().recipe_implementation.get_global_claim_validators( session.get_tenant_id(), session.get_user_id(), + session.get_recipe_user_id(), claim_validators_added_by_other_recipes, user_context, ) diff --git a/supertokens_python/recipe/thirdparty/api/implementation.py b/supertokens_python/recipe/thirdparty/api/implementation.py index c9daa33f0..75713db13 100644 --- a/supertokens_python/recipe/thirdparty/api/implementation.py +++ b/supertokens_python/recipe/thirdparty/api/implementation.py @@ -17,20 +17,21 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Union from urllib.parse import parse_qs, urlencode, urlparse +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId from supertokens_python.recipe.emailverification import EmailVerificationRecipe -from supertokens_python.recipe.emailverification.interfaces import ( - CreateEmailVerificationTokenOkResult, -) -from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.recipe.emailverification.asyncio import is_email_verified +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface, AuthorisationUrlGetOkResult, + SignInUpNotAllowed, + SignInUpOkResult, SignInUpPostNoEmailGivenByProviderResponse, SignInUpPostOkResult, ) from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo -from supertokens_python.recipe.thirdparty.types import UserInfoEmail +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo, UserInfoEmail if TYPE_CHECKING: from supertokens_python.recipe.thirdparty.interfaces import APIOptions @@ -62,14 +63,38 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Optional[RedirectUriInfo], oauth_tokens: Optional[Dict[str, Any]], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ SignInUpPostOkResult, SignInUpPostNoEmailGivenByProviderResponse, + SignInUpNotAllowed, GeneralErrorResponse, ]: + from supertokens_python.auth_utils import ( + OkResponse, + PostAuthChecksOkResponse, + SignInNotAllowedResponse, + SignUpNotAllowedResponse, + get_authenticating_user_and_add_to_current_tenant_if_required, + post_auth_checks, + pre_auth_checks, + ) + + error_code_map = { + "SIGN_UP_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_006)", + "SIGN_IN_NOT_ALLOWED": "Cannot sign in / up due to security reasons. Please try a different login method or contact support. (ERR_CODE_004)", + "LINKING_TO_SESSION_USER_FAILED": { + "EMAIL_VERIFICATION_REQUIRED": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_020)", + "RECIPE_USER_ID_ALREADY_LINKED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_021)", + "ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_022)", + "SESSION_USER_ACCOUNT_INFO_ALREADY_ASSOCIATED_WITH_ANOTHER_PRIMARY_USER_ID_ERROR": "Cannot sign in / up due to security reasons. Please contact support. (ERR_CODE_023)", + }, + } + oauth_tokens_to_use: Dict[str, Any] = {} if redirect_uri_info is not None: @@ -77,8 +102,10 @@ async def sign_in_up_post( redirect_uri_info=redirect_uri_info, user_context=user_context, ) + elif oauth_tokens is not None: + oauth_tokens_to_use = oauth_tokens else: - oauth_tokens_to_use = oauth_tokens # type: ignore + raise Exception("should never come here") user_info = await provider.get_user_info( oauth_tokens=oauth_tokens_to_use, @@ -88,60 +115,147 @@ async def sign_in_up_post( if user_info.email is None and provider.config.require_email is False: # We don't expect to get an email from this provider. # So we generate a fake one - if provider.config.generate_fake_email is not None: - user_info.email = UserInfoEmail( - email=await provider.config.generate_fake_email( - tenant_id, user_info.third_party_user_id, user_context - ), - is_verified=True, + assert provider.config.generate_fake_email is not None + user_info.email = UserInfoEmail( + email=await provider.config.generate_fake_email( + tenant_id, user_info.third_party_user_id, user_context + ), + is_verified=True, + ) + + email_info = user_info.email + if email_info is None: + return SignInUpPostNoEmailGivenByProviderResponse() + + recipe_id = "thirdparty" + + async def check_credentials_on_tenant(_: str): + # We essentially did this above when calling exchange_auth_code_for_oauth_tokens + return True + + authenticating_user = ( + await get_authenticating_user_and_add_to_current_tenant_if_required( + third_party=ThirdPartyInfo( + third_party_user_id=user_info.third_party_user_id, + third_party_id=provider.id, + ), + email=None, + phone_number=None, + recipe_id=recipe_id, + user_context=user_context, + session=session, + tenant_id=tenant_id, + check_credentials_on_tenant=check_credentials_on_tenant, + ) + ) + + is_sign_up = authenticating_user is None + if authenticating_user is not None: + # This is a sign in. So before we proceed, we need to check if an email change + # is allowed since the email could have changed from the social provider's side. + # We do this check here and not in the recipe function cause we want to keep the + # recipe function checks to a minimum so that the dev has complete control of + # what they can do. + + # The is_email_change_allowed and pre_auth_checks functions take an is_verified boolean. + # Now, even though we already have that from the input, that's just what the provider says. + # If the provider says that the email is NOT verified, it could have been that the email + # is verified on the user's account via supertokens on a previous sign in / up. + # So we just check that as well before calling is_email_change_allowed + + assert authenticating_user.login_method is not None + recipe_user_id = authenticating_user.login_method.recipe_user_id + + if ( + not email_info.is_verified + and EmailVerificationRecipe.get_instance_optional() is not None + ): + email_info.is_verified = await is_email_verified( + recipe_user_id, + email_info.id, + user_context, ) - email = user_info.email.id if user_info.email is not None else None - email_verified = ( - user_info.email.is_verified if user_info.email is not None else None + pre_auth_checks_result = await pre_auth_checks( + authenticating_account_info=AccountInfoWithRecipeId( + recipe_id=recipe_id, + email=email_info.id, + third_party=ThirdPartyInfo( + third_party_user_id=user_info.third_party_user_id, + third_party_id=provider.id, + ), + ), + authenticating_user=( + authenticating_user.user if authenticating_user else None + ), + factor_ids=["thirdparty"], + is_sign_up=is_sign_up, + is_verified=email_info.is_verified, + sign_in_verifies_login_method=email_info.is_verified, + skip_session_user_update_in_core=False, + tenant_id=tenant_id, + user_context=user_context, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) - if email is None: - return SignInUpPostNoEmailGivenByProviderResponse() + + if not isinstance(pre_auth_checks_result, OkResponse): + if isinstance(pre_auth_checks_result, SignUpNotAllowedResponse): + reason = error_code_map["SIGN_UP_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpNotAllowed(reason) + if isinstance(pre_auth_checks_result, SignInNotAllowedResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpNotAllowed(reason) + + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[pre_auth_checks_result.reason] + return SignInUpNotAllowed(reason=reason) signinup_response = await api_options.recipe_implementation.sign_in_up( third_party_id=provider.id, third_party_user_id=user_info.third_party_user_id, - email=email, + email=email_info.id, + is_verified=email_info.is_verified, oauth_tokens=oauth_tokens_to_use, raw_user_info_from_provider=user_info.raw_user_info_from_provider, + session=session, tenant_id=tenant_id, user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) - if email_verified: - ev_instance = EmailVerificationRecipe.get_instance_optional() - if ev_instance is not None: - token_response = await ev_instance.recipe_implementation.create_email_verification_token( - tenant_id=tenant_id, - user_id=signinup_response.user.user_id, - email=signinup_response.user.email, - user_context=user_context, - ) + if isinstance(signinup_response, SignInUpNotAllowed): + return signinup_response - if isinstance(token_response, CreateEmailVerificationTokenOkResult): - await ev_instance.recipe_implementation.verify_email_using_token( - token=token_response.token, - tenant_id=tenant_id, - user_context=user_context, - ) + if not isinstance(signinup_response, SignInUpOkResult): + reason_dict = error_code_map["LINKING_TO_SESSION_USER_FAILED"] + assert isinstance(reason_dict, Dict) + reason = reason_dict[signinup_response.reason] + return SignInUpNotAllowed(reason=reason) - user = signinup_response.user - session = await create_new_session( + post_auth_checks_result = await post_auth_checks( + factor_id="thirdparty", + is_sign_up=is_sign_up, + authenticated_user=signinup_response.user, + recipe_user_id=signinup_response.recipe_user_id, request=api_options.request, tenant_id=tenant_id, - user_id=user.user_id, user_context=user_context, + session=session, ) + if not isinstance(post_auth_checks_result, PostAuthChecksOkResponse): + reason = error_code_map["SIGN_IN_NOT_ALLOWED"] + assert isinstance(reason, str) + return SignInUpNotAllowed(reason) + return SignInUpPostOkResult( - created_new_user=signinup_response.created_new_user, - user=user, - session=session, + created_new_recipe_user=signinup_response.created_new_recipe_user, + user=post_auth_checks_result.user, + session=post_auth_checks_result.session, oauth_tokens=oauth_tokens_to_use, raw_user_info_from_provider=user_info.raw_user_info_from_provider, ) diff --git a/supertokens_python/recipe/thirdparty/api/signinup.py b/supertokens_python/recipe/thirdparty/api/signinup.py index 60937ea49..bcd561a0c 100644 --- a/supertokens_python/recipe/thirdparty/api/signinup.py +++ b/supertokens_python/recipe/thirdparty/api/signinup.py @@ -13,14 +13,19 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional +from supertokens_python.recipe.thirdparty.interfaces import SignInUpPostOkResult from supertokens_python.recipe.thirdparty.provider import RedirectUriInfo if TYPE_CHECKING: from supertokens_python.recipe.thirdparty.interfaces import APIOptions, APIInterface from supertokens_python.exceptions import raise_bad_input_exception, BadInputError -from supertokens_python.utils import send_200_response +from supertokens_python.utils import ( + get_backwards_compatible_user_info, + get_normalised_should_try_linking_with_session_user_flag, + send_200_response, +) async def handle_sign_in_up_api( @@ -29,6 +34,8 @@ async def handle_sign_in_up_api( api_options: APIOptions, user_context: Dict[str, Any], ): + from supertokens_python.auth_utils import load_session_in_auth_api_if_needed + if api_implementation.disable_sign_in_up_post: return None @@ -71,8 +78,9 @@ async def handle_sign_in_up_api( provider = provider_response + redirect_uri_info_parsed: Optional[RedirectUriInfo] = None if redirect_uri_info is not None: - redirect_uri_info = RedirectUriInfo( + redirect_uri_info_parsed = RedirectUriInfo( redirect_uri_on_provider_dashboard=redirect_uri_info.get( "redirectURIOnProviderDashboard" ), @@ -80,12 +88,43 @@ async def handle_sign_in_up_api( pkce_code_verifier=redirect_uri_info.get("pkceCodeVerifier"), ) + should_try_linking_with_session_user = ( + get_normalised_should_try_linking_with_session_user_flag( + api_options.request, body + ) + ) + + session = await load_session_in_auth_api_if_needed( + api_options.request, should_try_linking_with_session_user, user_context + ) + + if session is not None: + tenant_id = session.get_tenant_id() + result = await api_implementation.sign_in_up_post( provider=provider, - redirect_uri_info=redirect_uri_info, + redirect_uri_info=redirect_uri_info_parsed, oauth_tokens=oauth_tokens, tenant_id=tenant_id, api_options=api_options, user_context=user_context, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, ) + + if isinstance(result, SignInUpPostOkResult): + return send_200_response( + { + "status": "OK", + **get_backwards_compatible_user_info( + req=api_options.request, + user_info=result.user, + session_container=result.session, + created_new_recipe_user=result.created_new_recipe_user, + user_context=user_context, + ), + }, + api_options.response, + ) + return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/thirdparty/asyncio/__init__.py b/supertokens_python/recipe/thirdparty/asyncio/__init__.py index 0e70c5166..b374e041b 100644 --- a/supertokens_python/recipe/thirdparty/asyncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/asyncio/__init__.py @@ -12,66 +12,43 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.thirdparty.interfaces import ( + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe -from ..types import User - - -async def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - if user_context is None: - user_context = {} - return await ThirdPartyRecipe.get_instance().recipe_implementation.get_user_by_id( - user_id, user_context - ) - - -async def get_users_by_email( - tenant_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -) -> List[User]: - if user_context is None: - user_context = {} - return ( - await ThirdPartyRecipe.get_instance().recipe_implementation.get_users_by_email( - email, tenant_id, user_context - ) - ) - - -async def get_user_by_third_party_info( - tenant_id: str, - third_party_id: str, - third_party_user_id: str, - user_context: Union[None, Dict[str, Any]] = None, -): - if user_context is None: - user_context = {} - return await ThirdPartyRecipe.get_instance().recipe_implementation.get_user_by_thirdparty_info( - third_party_id, - third_party_user_id, - tenant_id, - user_context, - ) - async def manually_create_or_update_user( tenant_id: str, third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, -): +) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, +]: if user_context is None: user_context = {} return await ThirdPartyRecipe.get_instance().recipe_implementation.manually_create_or_update_user( - third_party_id, - third_party_user_id, - email, - tenant_id, - user_context, + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + email=email, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + user_context=user_context, + should_try_linking_with_session_user=session is not None, ) diff --git a/supertokens_python/recipe/thirdparty/interfaces.py b/supertokens_python/recipe/thirdparty/interfaces.py index b4c5c6bee..9a943c3aa 100644 --- a/supertokens_python/recipe/thirdparty/interfaces.py +++ b/supertokens_python/recipe/thirdparty/interfaces.py @@ -16,15 +16,16 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from ...types import APIResponse, GeneralErrorResponse +from ...types import APIResponse, User, GeneralErrorResponse, RecipeUserId from .provider import Provider, ProviderInput, RedirectUriInfo if TYPE_CHECKING: from supertokens_python.framework import BaseRequest, BaseResponse from supertokens_python.recipe.session import SessionContainer from supertokens_python.supertokens import AppInfo + from supertokens_python.auth_utils import LinkingToSessionUserFailedError - from .types import User, RawUserInfoFromProvider + from .types import RawUserInfoFromProvider from .utils import ThirdPartyConfig @@ -32,24 +33,28 @@ class SignInUpOkResult: def __init__( self, user: User, - created_new_user: bool, + recipe_user_id: RecipeUserId, + created_new_recipe_user: bool, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, ): self.user = user - self.created_new_user = created_new_user + self.created_new_recipe_user = created_new_recipe_user self.oauth_tokens = oauth_tokens self.raw_user_info_from_provider = raw_user_info_from_provider + self.recipe_user_id = recipe_user_id class ManuallyCreateOrUpdateUserOkResult: def __init__( self, user: User, - created_new_user: bool, + recipe_user_id: RecipeUserId, + created_new_recipe_user: bool, ): self.user = user - self.created_new_user = created_new_user + self.recipe_user_id = recipe_user_id + self.created_new_recipe_user = created_new_recipe_user class GetProviderOkResult: @@ -57,30 +62,24 @@ def __init__(self, provider: Provider): self.provider = provider -class RecipeInterface(ABC): - def __init__(self): - pass +class SignInUpNotAllowed(APIResponse): + status: str = "SIGN_IN_UP_NOT_ALLOWED" + reason: str - @abstractmethod - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - pass + def __init__(self, reason: str): + self.reason = reason - @abstractmethod - async def get_users_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> List[User]: - pass + def to_json(self) -> Dict[str, Any]: + return {"status": self.status, "reason": self.reason} - @abstractmethod - async def get_user_by_thirdparty_info( - self, - third_party_id: str, - third_party_user_id: str, - tenant_id: str, - user_context: Dict[str, Any], - ) -> Union[User, None]: + +class EmailChangeNotAllowedError: + def __init__(self, reason: str): + self.reason = reason + + +class RecipeInterface(ABC): + def __init__(self): pass @abstractmethod @@ -89,9 +88,17 @@ async def manually_create_or_update_user( third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], - ) -> ManuallyCreateOrUpdateUserOkResult: + ) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, + ]: pass @abstractmethod @@ -100,11 +107,14 @@ async def sign_in_up( third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], - ) -> SignInUpOkResult: + ) -> Union[SignInUpOkResult, SignInUpNotAllowed, LinkingToSessionUserFailedError]: pass @abstractmethod @@ -144,13 +154,13 @@ class SignInUpPostOkResult(APIResponse): def __init__( self, user: User, - created_new_user: bool, + created_new_recipe_user: bool, session: SessionContainer, oauth_tokens: Dict[str, Any], raw_user_info_from_provider: RawUserInfoFromProvider, ): self.user = user - self.created_new_user = created_new_user + self.created_new_recipe_user = created_new_recipe_user self.session = session self.oauth_tokens = oauth_tokens self.raw_user_info_from_provider = raw_user_info_from_provider @@ -158,16 +168,8 @@ def __init__( def to_json(self) -> Dict[str, Any]: return { "status": self.status, - "user": { - "id": self.user.user_id, - "email": self.user.email, - "timeJoined": self.user.time_joined, - "thirdParty": { - "id": self.user.third_party_info.id, - "userId": self.user.third_party_info.user_id, - }, - }, - "createdNewUser": self.created_new_user, + "user": self.user.to_json(), + "createdNewRecipeUser": self.created_new_recipe_user, } @@ -217,12 +219,15 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Optional[RedirectUriInfo], oauth_tokens: Optional[Dict[str, Any]], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ SignInUpPostOkResult, SignInUpPostNoEmailGivenByProviderResponse, + SignInUpNotAllowed, GeneralErrorResponse, ]: pass diff --git a/supertokens_python/recipe/thirdparty/provider.py b/supertokens_python/recipe/thirdparty/provider.py index 568242ca4..44e426cd3 100644 --- a/supertokens_python/recipe/thirdparty/provider.py +++ b/supertokens_python/recipe/thirdparty/provider.py @@ -111,6 +111,17 @@ def to_json(self) -> Dict[str, Any]: return {k: v for k, v in res.items() if v is not None} + @staticmethod + def from_json(json: Dict[str, Any]) -> ProviderClientConfig: + return ProviderClientConfig( + client_id=json.get("clientId", ""), + client_secret=json.get("clientSecret", None), + client_type=json.get("clientType", None), + scope=json.get("scope", None), + force_pkce=json.get("forcePKCE", None), + additional_config=json.get("additionalConfig", None), + ) + class UserFields: def __init__( @@ -132,6 +143,16 @@ def to_json(self) -> Dict[str, Any]: return {k: v for k, v in res.items() if v is not None} + @staticmethod + def from_json(json: Optional[Dict[str, Any]]) -> Optional[UserFields]: + if json is None: + return None + return UserFields( + user_id=json.get("userId", None), + email=json.get("email", None), + email_verified=json.get("emailVerified", None), + ) + class UserInfoMap: def __init__( @@ -150,6 +171,17 @@ def to_json(self) -> Dict[str, Any]: res["fromUserInfoAPI"] = self.from_user_info_api.to_json() return res + @staticmethod + def from_json(json: Optional[Dict[str, Any]]) -> Optional[UserInfoMap]: + if json is None: + return None + return UserInfoMap( + from_id_token_payload=UserFields.from_json( + json.get("fromIdTokenPayload", None) + ), + from_user_info_api=UserFields.from_json(json.get("fromUserInfoAPI", None)), + ) + class CommonProviderConfig: def __init__( @@ -213,9 +245,9 @@ def to_json(self) -> Dict[str, Any]: "userInfoEndpointHeaders": self.user_info_endpoint_headers, "jwksURI": self.jwks_uri, "oidcDiscoveryEndpoint": self.oidc_discovery_endpoint, - "userInfoMap": self.user_info_map.to_json() - if self.user_info_map is not None - else None, + "userInfoMap": ( + self.user_info_map.to_json() if self.user_info_map is not None else None + ), "requireEmail": self.require_email, } @@ -362,6 +394,34 @@ def to_json(self) -> Dict[str, Any]: return d + @staticmethod + def from_json(json: Dict[str, Any]) -> ProviderConfig: + return ProviderConfig( + third_party_id=json.get("thirdPartyId", ""), + name=json.get("name", None), + clients=[ + ProviderClientConfig.from_json(c) for c in json.get("clients", []) + ], + authorization_endpoint=json.get("authorizationEndpoint", None), + authorization_endpoint_query_params=json.get( + "authorizationEndpointQueryParams", None + ), + token_endpoint=json.get("tokenEndpoint", None), + token_endpoint_body_params=json.get("tokenEndpointBodyParams", None), + user_info_endpoint=json.get("userInfoEndpoint", None), + user_info_endpoint_query_params=json.get( + "userInfoEndpointQueryParams", None + ), + user_info_endpoint_headers=json.get("userInfoEndpointHeaders", None), + jwks_uri=json.get("jwksURI", None), + oidc_discovery_endpoint=json.get("oidcDiscoveryEndpoint", None), + user_info_map=UserInfoMap.from_json(json.get("userInfoMap", None)), + require_email=json.get("requireEmail", None), + validate_id_token_payload=None, + generate_fake_email=None, + validate_access_token=None, + ) + class ProviderInput: def __init__( diff --git a/supertokens_python/recipe/thirdparty/providers/active_directory.py b/supertokens_python/recipe/thirdparty/providers/active_directory.py index 14e20e827..8ebb00bc1 100644 --- a/supertokens_python/recipe/thirdparty/providers/active_directory.py +++ b/supertokens_python/recipe/thirdparty/providers/active_directory.py @@ -30,20 +30,17 @@ async def get_config_for_client_type( config = await super().get_config_for_client_type(client_type, user_context) if ( - config.additional_config is None - or config.additional_config.get("directoryId") is None + config.additional_config is not None + and config.additional_config.get("directoryId") is not None ): - if not config.oidc_discovery_endpoint: - raise Exception( - "Please provide the directoryId in the additionalConfig of the Active Directory provider." - ) - else: config.oidc_discovery_endpoint = f"https://login.microsoftonline.com/{config.additional_config['directoryId']}/v2.0/.well-known/openid-configuration" - # The config could be coming from core where we didn't add the well-known previously - config.oidc_discovery_endpoint = normalise_oidc_endpoint_to_include_well_known( - config.oidc_discovery_endpoint - ) + if config.oidc_discovery_endpoint is not None: + config.oidc_discovery_endpoint = ( + normalise_oidc_endpoint_to_include_well_known( + config.oidc_discovery_endpoint + ) + ) if config.scope is None: config.scope = ["openid", "email"] diff --git a/supertokens_python/recipe/thirdparty/providers/apple.py b/supertokens_python/recipe/thirdparty/providers/apple.py index 93aa7ebac..3bba95ef8 100644 --- a/supertokens_python/recipe/thirdparty/providers/apple.py +++ b/supertokens_python/recipe/thirdparty/providers/apple.py @@ -12,14 +12,17 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations +import json from re import sub from typing import Any, Dict, Optional from jwt import encode # type: ignore from time import time +from supertokens_python.recipe.thirdparty.types import UserInfo + from .custom import GenericProvider, NewProvider -from ..provider import Provider, ProviderConfigForClient, ProviderInput +from ..provider import Provider, ProviderConfigForClient, ProviderInput, RedirectUriInfo from .utils import ( get_actual_client_id_from_development_client_id, normalise_oidc_endpoint_to_include_well_known, @@ -76,6 +79,47 @@ async def _get_client_secret( # pylint: disable=no-self-use headers=headers, ) # type: ignore + async def exchange_auth_code_for_oauth_tokens( + self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any] + ) -> Dict[str, Any]: + response = await super().exchange_auth_code_for_oauth_tokens( + redirect_uri_info, user_context + ) + + user = redirect_uri_info.redirect_uri_query_params.get("user") + if user is not None: + if isinstance(user, str): + response["user"] = json.loads(user) + elif isinstance(user, dict): + response["user"] = user + + return response + + async def get_user_info( + self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] + ) -> UserInfo: + response = await super().get_user_info(oauth_tokens, user_context) + user = oauth_tokens.get("user") + + user_dict: Dict[str, Any] = {} + + if user is not None: + if isinstance(user, str): + user_dict = json.loads(user) + elif isinstance(user, dict): + user_dict = user + else: + return response + + if response.raw_user_info_from_provider.from_id_token_payload is None: + response.raw_user_info_from_provider.from_id_token_payload = {} + + response.raw_user_info_from_provider.from_id_token_payload[ + "user" + ] = user_dict + + return response + def Apple(input: ProviderInput) -> Provider: # pylint: disable=redefined-builtin if not input.config.name: diff --git a/supertokens_python/recipe/thirdparty/providers/custom.py b/supertokens_python/recipe/thirdparty/providers/custom.py index 9b6a355de..4d65ef894 100644 --- a/supertokens_python/recipe/thirdparty/providers/custom.py +++ b/supertokens_python/recipe/thirdparty/providers/custom.py @@ -437,11 +437,11 @@ async def get_user_info( def NewProvider( - input: ProviderInput, # pylint: disable=redefined-builtin + input_: ProviderInput, base_class: Callable[[ProviderConfig], Provider] = GenericProvider, ) -> Provider: - provider_instance = base_class(input.config) - if input.override is not None: - provider_instance = input.override(provider_instance) + provider_instance = base_class(input_.config) + if input_.override is not None: + provider_instance = input_.override(provider_instance) return provider_instance diff --git a/supertokens_python/recipe/thirdparty/providers/discord.py b/supertokens_python/recipe/thirdparty/providers/discord.py index 677f38032..d301cbcb5 100644 --- a/supertokens_python/recipe/thirdparty/providers/discord.py +++ b/supertokens_python/recipe/thirdparty/providers/discord.py @@ -45,7 +45,7 @@ def Discord(input: ProviderInput) -> Provider: # pylint: disable=redefined-buil input.config.name = "Discord" if not input.config.authorization_endpoint: - input.config.authorization_endpoint = "https://discord.com/api/oauth2/authorize" + input.config.authorization_endpoint = "https://discord.com/oauth2/authorize" if not input.config.token_endpoint: input.config.token_endpoint = "https://discord.com/api/oauth2/token" diff --git a/supertokens_python/recipe/thirdparty/providers/facebook.py b/supertokens_python/recipe/thirdparty/providers/facebook.py index d6ac1a890..327be5181 100644 --- a/supertokens_python/recipe/thirdparty/providers/facebook.py +++ b/supertokens_python/recipe/thirdparty/providers/facebook.py @@ -44,12 +44,55 @@ async def get_config_for_client_type( async def get_user_info( self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] ) -> UserInfo: + fields_permission_map = { + "public_profile": [ + "first_name", + "last_name", + "middle_name", + "name", + "name_format", + "picture", + "short_name", + ], + "email": ["id", "email"], + "user_birthday": ["birthday"], + "user_videos": ["videos"], + "user_posts": ["posts"], + "user_photos": ["photos"], + "user_location": ["location"], + "user_link": ["link"], + "user_likes": ["likes"], + "user_hometown": ["hometown"], + "user_gender": ["gender"], + "user_friends": ["friends"], + "user_age_range": ["age_range"], + } + scope_values = self.config.scope + + fields = ( + ",".join( + [ + field + for scope in scope_values + for field in fields_permission_map.get(scope, []) + ] + ) + if scope_values + else "id,email" + ) + self.config.user_info_endpoint_query_params = { "access_token": str(oauth_tokens["access_token"]), - "fields": "id,email", + "fields": fields, "format": "json", **(self.config.user_info_endpoint_query_params or {}), } + + self.config.user_info_endpoint_headers = { + **(self.config.user_info_endpoint_headers or {}), + "Authorization": None, + } + return await super().get_user_info(oauth_tokens, user_context) diff --git a/supertokens_python/recipe/thirdparty/providers/gitlab.py b/supertokens_python/recipe/thirdparty/providers/gitlab.py index de768067c..bdbcac750 100644 --- a/supertokens_python/recipe/thirdparty/providers/gitlab.py +++ b/supertokens_python/recipe/thirdparty/providers/gitlab.py @@ -47,9 +47,10 @@ async def get_config_for_client_type( oidc_domain.get_as_string_dangerous() + oidc_path.get_as_string_dangerous() ) - - if not config.oidc_discovery_endpoint: - raise Exception("should never come here") + elif config.oidc_discovery_endpoint is None: + config.oidc_discovery_endpoint = ( + "https://gitlab.com/.well-known/openid-configuration" + ) # The config could be coming from core where we didn't add the well-known previously config.oidc_discovery_endpoint = normalise_oidc_endpoint_to_include_well_known( diff --git a/supertokens_python/recipe/thirdparty/providers/okta.py b/supertokens_python/recipe/thirdparty/providers/okta.py index 9950afe39..56e180029 100644 --- a/supertokens_python/recipe/thirdparty/providers/okta.py +++ b/supertokens_python/recipe/thirdparty/providers/okta.py @@ -32,14 +32,9 @@ async def get_config_for_client_type( config = await super().get_config_for_client_type(client_type, user_context) if ( - config.additional_config is None - or config.additional_config.get("oktaDomain") is None + config.additional_config is not None + and config.additional_config.get("oktaDomain") is not None ): - if not config.oidc_discovery_endpoint: - raise Exception( - "Please provide the oktaDomain in the additionalConfig of the Okta provider." - ) - else: okta_domain = config.additional_config["oktaDomain"] oidc_domain = NormalisedURLDomain(okta_domain) oidc_path = NormalisedURLPath("/.well-known/openid-configuration") @@ -48,13 +43,12 @@ async def get_config_for_client_type( + oidc_path.get_as_string_dangerous() ) - if not config.oidc_discovery_endpoint: - raise Exception("should never happen") - - # The config could be coming from core where we didn't add the well-known previously - config.oidc_discovery_endpoint = normalise_oidc_endpoint_to_include_well_known( - config.oidc_discovery_endpoint - ) + if config.oidc_discovery_endpoint is not None: + config.oidc_discovery_endpoint = ( + normalise_oidc_endpoint_to_include_well_known( + config.oidc_discovery_endpoint + ) + ) if config.scope is None: config.scope = ["openid", "email"] diff --git a/supertokens_python/recipe/thirdparty/providers/utils.py b/supertokens_python/recipe/thirdparty/providers/utils.py index 9f9aaf80b..1f274461a 100644 --- a/supertokens_python/recipe/thirdparty/providers/utils.py +++ b/supertokens_python/recipe/thirdparty/providers/utils.py @@ -48,7 +48,7 @@ async def do_get_request( async def do_post_request( url: str, - body_params: Optional[Dict[str, str]] = None, + body_params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> Tuple[int, Dict[str, Any]]: if body_params is None: @@ -64,7 +64,10 @@ async def do_post_request( log_debug_message( "Received response with status %s and body %s", res.status_code, res.text ) - return res.status_code, res.json() + try: + return res.status_code, res.json() + except Exception: + return res.status_code, {"message": res.text} def normalise_oidc_endpoint_to_include_well_known(url: str) -> str: diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index 5c234c637..9a7ca219f 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -23,7 +23,6 @@ from .api.implementation import APIImplementation from .interfaces import APIInterface, APIOptions, RecipeInterface from .recipe_implementation import RecipeImplementation -from ..emailverification.interfaces import GetEmailForUserIdOkResult, UnknownUserIdError from ...post_init_callbacks import PostSTInitCallbacks if TYPE_CHECKING: @@ -33,7 +32,6 @@ from .utils import SignInAndUpFeature, InputOverrideConfig from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from .api import ( @@ -81,13 +79,10 @@ def __init__( ) def callback(): - ev_recipe = EmailVerificationRecipe.get_instance_optional() - if ev_recipe: - ev_recipe.add_get_email_for_user_id_func(self.get_email_for_user_id) - mt_recipe = MultitenancyRecipe.get_instance_optional() if mt_recipe: mt_recipe.static_third_party_providers = self.providers + mt_recipe.all_available_first_factors.append("thirdparty") PostSTInitCallbacks.add_post_init_callback(callback) @@ -205,12 +200,3 @@ def reset(): ThirdPartyRecipe.__instance = None # instance functions below............... - - async def get_email_for_user_id(self, user_id: str, user_context: Dict[str, Any]): - user_info = await self.recipe_implementation.get_user_by_id( - user_id, user_context - ) - if user_info is not None: - return GetEmailForUserIdOkResult(user_info.email) - - return UnknownUserIdError() diff --git a/supertokens_python/recipe/thirdparty/recipe_implementation.py b/supertokens_python/recipe/thirdparty/recipe_implementation.py index 328fc0129..3a0a2824d 100644 --- a/supertokens_python/recipe/thirdparty/recipe_implementation.py +++ b/supertokens_python/recipe/thirdparty/recipe_implementation.py @@ -14,25 +14,34 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.thirdparty.provider import ProviderInput from supertokens_python.recipe.thirdparty.providers.config_utils import ( find_and_create_provider_instance, merge_providers_from_core_and_static, ) +from supertokens_python.types import AccountInfo, User, RecipeUserId if TYPE_CHECKING: from supertokens_python.querier import Querier + from supertokens_python.auth_utils import ( + LinkingToSessionUserFailedError, + ) from .interfaces import ( + EmailChangeNotAllowedError, ManuallyCreateOrUpdateUserOkResult, RecipeInterface, + SignInUpNotAllowed, SignInUpOkResult, ) -from .types import RawUserInfoFromProvider, ThirdPartyInfo, User +from .types import RawUserInfoFromProvider, ThirdPartyInfo class RecipeImplementation(RecipeInterface): @@ -41,149 +50,146 @@ def __init__(self, querier: Querier, providers: List[ProviderInput]): self.querier = querier self.providers = providers - async def get_user_by_id( - self, user_id: str, user_context: Dict[str, Any] - ) -> Union[User, None]: - params = {"userId": user_id} - response = await self.querier.send_get_request( - NormalisedURLPath("/recipe/user"), - params, - user_context=user_context, - ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), - ) - return None - - async def get_users_by_email( - self, email: str, tenant_id: str, user_context: Dict[str, Any] - ) -> List[User]: - response = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/users/by-email"), - {"email": email}, - user_context=user_context, - ) - users: List[User] = [] - users_list: List[Dict[str, Any]] = ( - response["users"] if "users" in response else [] - ) - for user in users_list: - users.append( - User( - user["id"], - user["email"], - user["timeJoined"], - user["tenantIds"], - ThirdPartyInfo( - user["thirdParty"]["userId"], user["thirdParty"]["id"] - ), - ) - ) - return users - - async def get_user_by_thirdparty_info( + async def sign_in_up( self, third_party_id: str, third_party_user_id: str, + email: str, + is_verified: bool, + oauth_tokens: Dict[str, Any], + raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], - ) -> Union[User, None]: - params = { - "thirdPartyId": third_party_id, - "thirdPartyUserId": third_party_user_id, - } - response = await self.querier.send_get_request( - NormalisedURLPath(f"{tenant_id}/recipe/user"), - params, + ) -> Union[SignInUpOkResult, SignInUpNotAllowed, LinkingToSessionUserFailedError]: + response = await self.manually_create_or_update_user( + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + email=email, + tenant_id=tenant_id, + is_verified=is_verified, + session=session, + should_try_linking_with_session_user=should_try_linking_with_session_user, user_context=user_context, ) - if "status" in response and response["status"] == "OK": - return User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), + + if isinstance(response, EmailChangeNotAllowedError): + return SignInUpNotAllowed( + "Cannot sign in / up because new email cannot be applied to existing account. Please contact support. (ERR_CODE_005)" + if response.reason + == "Email already associated with another primary user." + else "Cannot sign in / up because new email cannot be applied to existing account. Please contact support. (ERR_CODE_024)" ) - return None - async def sign_in_up( + if isinstance(response, ManuallyCreateOrUpdateUserOkResult): + return SignInUpOkResult( + user=response.user, + recipe_user_id=response.recipe_user_id, + created_new_recipe_user=response.created_new_recipe_user, + oauth_tokens=oauth_tokens, + raw_user_info_from_provider=raw_user_info_from_provider, + ) + + return response + + async def manually_create_or_update_user( self, third_party_id: str, third_party_user_id: str, email: str, - oauth_tokens: Dict[str, Any], - raw_user_info_from_provider: RawUserInfoFromProvider, + is_verified: bool, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, user_context: Dict[str, Any], - ) -> SignInUpOkResult: - data = { - "thirdPartyId": third_party_id, - "thirdPartyUserId": third_party_user_id, - "email": {"id": email}, - } - response = await self.querier.send_post_request( - NormalisedURLPath(f"{tenant_id}/recipe/signinup"), - data, - user_context=user_context, + ) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, + ]: + from supertokens_python.auth_utils import ( + link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info, ) - return SignInUpOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], + + account_linking = AccountLinkingRecipe.get_instance() + users = await list_users_by_account_info( + tenant_id, + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, ), ), - response["createdNewUser"], - oauth_tokens, - raw_user_info_from_provider, + False, + user_context, ) - async def manually_create_or_update_user( - self, - third_party_id: str, - third_party_user_id: str, - email: str, - tenant_id: str, - user_context: Dict[str, Any], - ) -> ManuallyCreateOrUpdateUserOkResult: - data = { - "thirdPartyId": third_party_id, - "thirdPartyUserId": third_party_user_id, - "email": {"id": email}, - } + user = users[0] if users else None + if user is not None: + is_email_change_allowed = await account_linking.is_email_change_allowed( + user=user, + is_verified=is_verified, + new_email=email, + session=session, + user_context=user_context, + ) + if not is_email_change_allowed.allowed: + reason = ( + "Email already associated with another primary user." + if is_email_change_allowed.reason == "PRIMARY_USER_CONFLICT" + else "New email cannot be applied to existing account because of account takeover risks." + ) + return EmailChangeNotAllowedError(reason) + response = await self.querier.send_post_request( NormalisedURLPath(f"{tenant_id}/recipe/signinup"), - data, + { + "thirdPartyId": third_party_id, + "thirdPartyUserId": third_party_user_id, + "email": {"id": email, "isVerified": is_verified}, + }, user_context=user_context, ) + + if response["status"] == "EMAIL_CHANGE_NOT_ALLOWED_ERROR": + return EmailChangeNotAllowedError(response["reason"]) + + # status is OK + + user = User.from_json( + response["user"], + ) + recipe_user_id = RecipeUserId(response["recipeUserId"]) + + await account_linking.verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=user_context, + ) + + # Fetch updated user + user = await get_user(recipe_user_id.get_as_string(), user_context) + + assert user is not None + + link_result = await link_to_session_if_provided_else_create_primary_user_id_or_link_by_account_info( + tenant_id=tenant_id, + input_user=user, + recipe_user_id=recipe_user_id, + session=session, + user_context=user_context, + should_try_linking_with_session_user=should_try_linking_with_session_user, + ) + + if link_result.status != "OK": + return link_result + return ManuallyCreateOrUpdateUserOkResult( - User( - response["user"]["id"], - response["user"]["email"], - response["user"]["timeJoined"], - response["user"]["tenantIds"], - ThirdPartyInfo( - response["user"]["thirdParty"]["userId"], - response["user"]["thirdParty"]["id"], - ), - ), - response["createdNewUser"], + user=link_result.user, + recipe_user_id=recipe_user_id, + created_new_recipe_user=response["createdNewUser"], ) async def get_provider( @@ -203,7 +209,7 @@ async def get_provider( raise Exception("Tenant not found") merged_providers = merge_providers_from_core_and_static( - provider_configs_from_core=tenant_config.third_party.providers, + provider_configs_from_core=tenant_config.third_party_providers, provider_inputs_from_static=self.providers, include_all_providers=tenant_id == DEFAULT_TENANT_ID, ) diff --git a/supertokens_python/recipe/thirdparty/syncio/__init__.py b/supertokens_python/recipe/thirdparty/syncio/__init__.py index 218c6862e..4481c8131 100644 --- a/supertokens_python/recipe/thirdparty/syncio/__init__.py +++ b/supertokens_python/recipe/thirdparty/syncio/__init__.py @@ -11,46 +11,16 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union from supertokens_python.async_to_sync_wrapper import sync - -from ..types import User - - -def get_user_by_id( - user_id: str, user_context: Union[None, Dict[str, Any]] = None -) -> Union[User, None]: - from supertokens_python.recipe.thirdparty.asyncio import get_user_by_id - - return sync(get_user_by_id(user_id, user_context)) - - -def get_users_by_email( - tenant_id: str, - email: str, - user_context: Union[None, Dict[str, Any]] = None, -) -> List[User]: - from supertokens_python.recipe.thirdparty.asyncio import get_users_by_email - - return sync(get_users_by_email(tenant_id, email, user_context)) - - -def get_user_by_third_party_info( - tenant_id: str, - third_party_id: str, - third_party_user_id: str, - user_context: Union[None, Dict[str, Any]] = None, -): - from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_third_party_info, - ) - - return sync( - get_user_by_third_party_info( - tenant_id, third_party_id, third_party_user_id, user_context - ) - ) +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.thirdparty.interfaces import ( + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) def manually_create_or_update_user( @@ -58,15 +28,28 @@ def manually_create_or_update_user( third_party_id: str, third_party_user_id: str, email: str, + is_verified: bool, + session: Optional[SessionContainer] = None, user_context: Union[None, Dict[str, Any]] = None, -): +) -> Union[ + ManuallyCreateOrUpdateUserOkResult, + LinkingToSessionUserFailedError, + SignInUpNotAllowed, + EmailChangeNotAllowedError, +]: from supertokens_python.recipe.thirdparty.asyncio import ( manually_create_or_update_user, ) return sync( manually_create_or_update_user( - tenant_id, third_party_id, third_party_user_id, email, user_context + email=email, + is_verified=is_verified, + session=session, + tenant_id=tenant_id, + third_party_id=third_party_id, + third_party_user_id=third_party_user_id, + user_context=user_context, ) ) diff --git a/supertokens_python/recipe/thirdparty/types.py b/supertokens_python/recipe/thirdparty/types.py index 5d18c6b3a..3e8ec7037 100644 --- a/supertokens_python/recipe/thirdparty/types.py +++ b/supertokens_python/recipe/thirdparty/types.py @@ -11,10 +11,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Union, Optional +from __future__ import annotations + +from typing import Any, Callable, Dict, Union, Optional, TYPE_CHECKING from supertokens_python.framework.request import BaseRequest +if TYPE_CHECKING: + from supertokens_python.types import User + class ThirdPartyInfo: def __init__(self, third_party_user_id: str, third_party_id: str): @@ -28,6 +33,9 @@ def __eq__(self, other: object) -> bool: and self.id == other.id ) + def to_json(self) -> Dict[str, Any]: + return {"userId": self.user_id, "id": self.id} + class RawUserInfoFromProvider: def __init__( @@ -38,31 +46,11 @@ def __init__( self.from_id_token_payload = from_id_token_payload self.from_user_info_api = from_user_info_api - -class User: - def __init__( - self, - user_id: str, - email: str, - time_joined: int, - tenant_ids: List[str], - third_party_info: ThirdPartyInfo, - ): - self.user_id: str = user_id - self.email: str = email - self.time_joined: int = time_joined - self.tenant_ids = tenant_ids - self.third_party_info: ThirdPartyInfo = third_party_info - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, self.__class__) - and self.user_id == other.user_id - and self.email == other.email - and self.time_joined == other.time_joined - and self.tenant_ids == other.tenant_ids - and self.third_party_info == other.third_party_info - ) + def to_json(self) -> Dict[str, Any]: + return { + "fromIdTokenPayload": self.from_id_token_payload, + "fromUserInfoApi": self.from_user_info_api, + } class UserInfoEmail: @@ -70,6 +58,9 @@ def __init__(self, email: str, is_verified: bool): self.id: str = email self.is_verified: bool = is_verified + def to_json(self) -> Dict[str, Any]: + return {"id": self.id, "isVerified": self.is_verified} + class UserInfo: def __init__( @@ -84,6 +75,13 @@ def __init__( raw_user_info_from_provider or RawUserInfoFromProvider({}, {}) ) + def to_json(self) -> Dict[str, Any]: + return { + "thirdPartyUserId": self.third_party_user_id, + "email": self.email.to_json() if self.email is not None else None, + "rawUserInfoFromProvider": self.raw_user_info_from_provider.to_json(), + } + class AccessTokenAPI: def __init__(self, url: str, params: Dict[str, str]): @@ -105,11 +103,5 @@ def __init__(self, user: User, is_new_user: bool): self.is_new_user = is_new_user -class UsersResponse: - def __init__(self, users: List[User], next_pagination_token: Union[str, None]): - self.users = users - self.next_pagination_token = next_pagination_token - - class ThirdPartyIngredients: pass diff --git a/supertokens_python/recipe/totp/__init__.py b/supertokens_python/recipe/totp/__init__.py new file mode 100644 index 000000000..f89944688 --- /dev/null +++ b/supertokens_python/recipe/totp/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Union + +from supertokens_python.recipe.totp.types import TOTPConfig + +from .recipe import TOTPRecipe + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo + + from ...recipe_module import RecipeModule + + +def init( + config: Union[TOTPConfig, None] = None, +) -> Callable[[AppInfo], RecipeModule]: + return TOTPRecipe.init( + config=config, + ) diff --git a/supertokens_python/recipe/totp/api/create_device.py b/supertokens_python/recipe/totp/api/create_device.py new file mode 100644 index 000000000..7a7ba2ef2 --- /dev/null +++ b/supertokens_python/recipe/totp/api/create_device.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union + +from supertokens_python.framework import BaseResponse + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + +from supertokens_python.exceptions import raise_bad_input_exception +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + + +async def handle_create_device_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_create_device_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + + device_name = body.get("deviceName") + + if device_name is not None and not isinstance(device_name, str): + raise_bad_input_exception("deviceName must be a string") + + response = await api_implementation.create_device_post( + device_name, api_options, session, user_context + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/implementation.py b/supertokens_python/recipe/totp/api/implementation.py new file mode 100644 index 000000000..3713a7c4d --- /dev/null +++ b/supertokens_python/recipe/totp/api/implementation.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Dict, Any, Union +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.exceptions import UnauthorisedError +from supertokens_python.recipe.multifactorauth.asyncio import ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from ..interfaces import APIInterface, APIOptions +from ..types import ( + CreateDeviceOkResult, + DeviceAlreadyExistsError, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + VerifyTOTPOkResult, + UnknownUserIdError, +) +from supertokens_python.types import GeneralErrorResponse + + +class APIImplementation(APIInterface): + async def create_device_post( + self, + device_name: Union[str, None], + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, GeneralErrorResponse]: + user_id = session.get_user_id() + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is None: + raise Exception("should never come here") + + await assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session, "totp", user_context + ) + + create_device_res = await options.recipe_implementation.create_device( + user_id=user_id, + user_identifier_info=None, + device_name=device_name, + skew=None, + period=None, + user_context=user_context, + ) + + if isinstance(create_device_res, UnknownUserIdError): + raise UnauthorisedError("Session user not found") + return create_device_res + + async def list_devices_get( + self, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ListDevicesOkResult, GeneralErrorResponse]: + user_id = session.get_user_id() + + return await options.recipe_implementation.list_devices( + user_id=user_id, + user_context=user_context, + ) + + async def remove_device_post( + self, + device_name: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[RemoveDeviceOkResult, GeneralErrorResponse]: + user_id = session.get_user_id() + + return await options.recipe_implementation.remove_device( + user_id=user_id, + device_name=device_name, + user_context=user_context, + ) + + async def verify_device_post( + self, + device_name: str, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + user_id = session.get_user_id() + tenant_id = session.get_tenant_id() + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is None: + raise Exception("should never come here") + + await assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session, "totp", user_context + ) + + res = await options.recipe_implementation.verify_device( + tenant_id=tenant_id, + user_id=user_id, + device_name=device_name, + totp=totp, + user_context=user_context, + ) + + if isinstance(res, VerifyDeviceOkResult): + await mfa_instance.recipe_implementation.mark_factor_as_complete_in_session( + session=session, + factor_id="totp", + user_context=user_context, + ) + + return res + + async def verify_totp_post( + self, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + user_id = session.get_user_id() + tenant_id = session.get_tenant_id() + + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is None: + raise Exception("should never come here") + + res = await options.recipe_implementation.verify_totp( + tenant_id=tenant_id, + user_id=user_id, + totp=totp, + user_context=user_context, + ) + + if isinstance(res, VerifyTOTPOkResult): + await mfa_instance.recipe_implementation.mark_factor_as_complete_in_session( + session=session, + factor_id="totp", + user_context=user_context, + ) + + return res diff --git a/supertokens_python/recipe/totp/api/list_devices.py b/supertokens_python/recipe/totp/api/list_devices.py new file mode 100644 index 000000000..3030fa3a3 --- /dev/null +++ b/supertokens_python/recipe/totp/api/list_devices.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_list_devices_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_list_devices_get: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + response = await api_implementation.list_devices_get( + api_options, session, user_context + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/remove_device.py b/supertokens_python/recipe/totp/api/remove_device.py new file mode 100644 index 000000000..128424b94 --- /dev/null +++ b/supertokens_python/recipe/totp/api/remove_device.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_remove_device_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_remove_device_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + device_name = body.get("deviceName") + + if device_name is None or not isinstance(device_name, str) or len(device_name) == 0: + raise Exception("deviceName is required and must be a non-empty string") + + response = await api_implementation.remove_device_post( + device_name=device_name, + options=api_options, + session=session, + user_context=user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/verify_device.py b/supertokens_python/recipe/totp/api/verify_device.py new file mode 100644 index 000000000..b57e091b1 --- /dev/null +++ b/supertokens_python/recipe/totp/api/verify_device.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_verify_device_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_verify_device_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + + device_name = body.get("deviceName") + totp = body.get("totp") + + if device_name is None or not isinstance(device_name, str): + raise Exception("deviceName is required and must be a string") + + if totp is None or not isinstance(totp, str): + raise Exception("totp is required and must be a string") + + response = await api_implementation.verify_device_post( + device_name=device_name, + totp=totp, + options=api_options, + session=session, + user_context=user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/api/verify_totp.py b/supertokens_python/recipe/totp/api/verify_totp.py new file mode 100644 index 000000000..1dbe91218 --- /dev/null +++ b/supertokens_python/recipe/totp/api/verify_totp.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Union +from supertokens_python.exceptions import raise_bad_input_exception + +from supertokens_python.framework import BaseResponse +from supertokens_python.utils import send_200_response +from supertokens_python.recipe.session.asyncio import get_session + +if TYPE_CHECKING: + from supertokens_python.recipe.totp.interfaces import APIOptions, APIInterface + + +async def handle_verify_totp_api( + tenant_id: str, + api_implementation: APIInterface, + api_options: APIOptions, + user_context: Dict[str, Any], +) -> Union[BaseResponse, None]: + if api_implementation.disable_verify_totp_post: + return None + + session = await get_session( + api_options.request, + override_global_claim_validators=lambda _, __, ___: [], + session_required=True, + user_context=user_context, + ) + + assert session is not None + + body = await api_options.request.json() + if body is None: + raise_bad_input_exception("Please provide a JSON body") + + totp = body.get("totp") + + if totp is None or not isinstance(totp, str): + raise Exception("totp is required and must be a string") + + response = await api_implementation.verify_totp_post( + totp=totp, + options=api_options, + session=session, + user_context=user_context, + ) + + return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/totp/asyncio/__init__.py b/supertokens_python/recipe/totp/asyncio/__init__.py new file mode 100644 index 000000000..5d83bbabb --- /dev/null +++ b/supertokens_python/recipe/totp/asyncio/__init__.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, Union, Optional +from ..recipe import TOTPRecipe +from supertokens_python.recipe.totp.types import ( + CreateDeviceOkResult, + DeviceAlreadyExistsError, + UnknownUserIdError, + UpdateDeviceOkResult, + UnknownDeviceError, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + InvalidTOTPError, + LimitReachedError, + VerifyTOTPOkResult, +) + + +async def create_device( + user_id: str, + user_identifier_info: Optional[str] = None, + device_name: Optional[str] = None, + skew: Optional[int] = None, + period: Optional[int] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.create_device( + user_id, + user_identifier_info, + device_name, + skew, + period, + user_context, + ) + + +async def update_device( + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.update_device( + user_id, + existing_device_name, + new_device_name, + user_context, + ) + + +async def list_devices( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> ListDevicesOkResult: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.list_devices( + user_id, + user_context, + ) + + +async def remove_device( + user_id: str, + device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> RemoveDeviceOkResult: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.remove_device( + user_id, + device_name, + user_context, + ) + + +async def verify_device( + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + VerifyDeviceOkResult, UnknownDeviceError, InvalidTOTPError, LimitReachedError +]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.verify_device( + tenant_id, + user_id, + device_name, + totp, + user_context, + ) + + +async def verify_totp( + tenant_id: str, + user_id: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[VerifyTOTPOkResult, UnknownUserIdError, InvalidTOTPError, LimitReachedError]: + if user_context is None: + user_context = {} + return await TOTPRecipe.get_instance_or_throw().recipe_implementation.verify_totp( + tenant_id, + user_id, + totp, + user_context, + ) diff --git a/supertokens_python/recipe/totp/constants.py b/supertokens_python/recipe/totp/constants.py new file mode 100644 index 000000000..bf421c46c --- /dev/null +++ b/supertokens_python/recipe/totp/constants.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +CREATE_TOTP_DEVICE = "/totp/device" +LIST_TOTP_DEVICES = "/totp/device/list" +REMOVE_TOTP_DEVICE = "/totp/device/remove" +VERIFY_TOTP_DEVICE = "/totp/device/verify" +VERIFY_TOTP = "/totp/verify" diff --git a/supertokens_python/recipe/totp/interfaces.py b/supertokens_python/recipe/totp/interfaces.py new file mode 100644 index 000000000..64c45f783 --- /dev/null +++ b/supertokens_python/recipe/totp/interfaces.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +from typing import Dict, Any, Union, TYPE_CHECKING, Optional +from abc import ABC, abstractmethod + +if TYPE_CHECKING: + from .types import ( + UserIdentifierInfoOkResult, + UnknownUserIdError, + UserIdentifierInfoDoesNotExistError, + CreateDeviceOkResult, + DeviceAlreadyExistsError, + UpdateDeviceOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + VerifyTOTPOkResult, + InvalidTOTPError, + LimitReachedError, + UnknownDeviceError, + ListDevicesOkResult, + TOTPNormalisedConfig, + ) + from supertokens_python.recipe.session import SessionContainer + from supertokens_python import AppInfo + from supertokens_python.framework import BaseRequest, BaseResponse + from supertokens_python.recipe.totp.recipe import TOTPRecipe + from supertokens_python.types import GeneralErrorResponse + + +class RecipeInterface(ABC): + @abstractmethod + async def get_user_identifier_info_for_user_id( + self, user_id: str, user_context: Dict[str, Any] + ) -> Union[ + UserIdentifierInfoOkResult, + UnknownUserIdError, + UserIdentifierInfoDoesNotExistError, + ]: + pass + + @abstractmethod + async def create_device( + self, + user_id: str, + user_identifier_info: Optional[str], + device_name: Optional[str], + skew: Optional[int], + period: Optional[int], + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError,]: + pass + + @abstractmethod + async def update_device( + self, + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Dict[str, Any], + ) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError,]: + pass + + @abstractmethod + async def list_devices( + self, user_id: str, user_context: Dict[str, Any] + ) -> ListDevicesOkResult: + pass + + @abstractmethod + async def remove_device( + self, user_id: str, device_name: str, user_context: Dict[str, Any] + ) -> RemoveDeviceOkResult: + pass + + @abstractmethod + async def verify_device( + self, + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + ]: + pass + + @abstractmethod + async def verify_totp( + self, tenant_id: str, user_id: str, totp: str, user_context: Dict[str, Any] + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + ]: + pass + + +class APIOptions: + def __init__( + self, + request: BaseRequest, + response: BaseResponse, + recipe_id: str, + config: TOTPNormalisedConfig, + recipe_implementation: RecipeInterface, + app_info: AppInfo, + recipe_instance: TOTPRecipe, + ): + self.request: BaseRequest = request + self.response: BaseResponse = response + self.recipe_id: str = recipe_id + self.config = config + self.recipe_implementation: RecipeInterface = recipe_implementation + self.app_info = app_info + self.recipe_instance = recipe_instance + + +class APIInterface(ABC): + def __init__(self): + self.disable_create_device_post = False + self.disable_list_devices_get = False + self.disable_remove_device_post = False + self.disable_verify_device_post = False + self.disable_verify_totp_post = False + + @abstractmethod + async def create_device_post( + self, + device_name: Union[str, None], + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, GeneralErrorResponse]: + pass + + @abstractmethod + async def list_devices_get( + self, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ListDevicesOkResult, GeneralErrorResponse]: + pass + + @abstractmethod + async def remove_device_post( + self, + device_name: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[RemoveDeviceOkResult, GeneralErrorResponse]: + pass + + @abstractmethod + async def verify_device_post( + self, + device_name: str, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + pass + + @abstractmethod + async def verify_totp_post( + self, + totp: str, + options: APIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + GeneralErrorResponse, + ]: + pass diff --git a/supertokens_python/recipe/totp/recipe.py b/supertokens_python/recipe/totp/recipe.py new file mode 100644 index 000000000..bcb5ba23c --- /dev/null +++ b/supertokens_python/recipe/totp/recipe.py @@ -0,0 +1,236 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from os import environ +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.multifactorauth.types import ( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc, + GetFactorsSetupForUserFromOtherRecipesFunc, +) +from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe_module import APIHandled, RecipeModule +from supertokens_python.querier import Querier +from supertokens_python.types import User + +from .recipe_implementation import RecipeImplementation +from .api.implementation import APIImplementation +from .interfaces import APIInterface, RecipeInterface +from .utils import validate_and_normalise_user_input +from ...post_init_callbacks import PostSTInitCallbacks +from ..multifactorauth.recipe import MultiFactorAuthRecipe + +if TYPE_CHECKING: + from supertokens_python.framework.request import BaseRequest + from supertokens_python.framework.response import BaseResponse + from supertokens_python.supertokens import AppInfo + +from supertokens_python.exceptions import SuperTokensError, raise_general_exception + +from .api.list_devices import handle_list_devices_api +from .api.create_device import handle_create_device_api +from .api.remove_device import handle_remove_device_api +from .api.verify_device import handle_verify_device_api +from .api.verify_totp import handle_verify_totp_api +from .constants import ( + CREATE_TOTP_DEVICE, + VERIFY_TOTP_DEVICE, + VERIFY_TOTP, + LIST_TOTP_DEVICES, + REMOVE_TOTP_DEVICE, +) +from .interfaces import APIOptions +from .types import TOTPConfig + + +class TOTPRecipe(RecipeModule): + recipe_id = "totp" + __instance = None + + def __init__( + self, + recipe_id: str, + app_info: AppInfo, + config: Union[TOTPConfig, None] = None, + ): + super().__init__(recipe_id, app_info) + self.config = validate_and_normalise_user_input(app_info, config) + + recipe_implementation = RecipeImplementation( + Querier.get_instance(recipe_id), self.config + ) + self.recipe_implementation: RecipeInterface = ( + recipe_implementation + if self.config.override.functions is None + else self.config.override.functions(recipe_implementation) + ) + + api_implementation = APIImplementation() + self.api_implementation: APIInterface = ( + api_implementation + if self.config.override.apis is None + else self.config.override.apis(api_implementation) + ) + + def callback(): + mfa_instance = MultiFactorAuthRecipe.get_instance() + if mfa_instance is not None: + + async def f1(_: TenantConfig): + return ["totp"] + + async def f2(user: User, user_context: Dict[str, Any]) -> List[str]: + device_res = await TOTPRecipe.get_instance_or_throw().recipe_implementation.list_devices( + user_id=user.id, user_context=user_context + ) + for device in device_res.devices: + if device.verified: + return ["totp"] + return [] + + mfa_instance.add_func_to_get_all_available_secondary_factor_ids_from_other_recipes( + GetAllAvailableSecondaryFactorIdsFromOtherRecipesFunc(f1) + ) + mfa_instance.add_func_to_get_factors_setup_for_user_from_other_recipes( + GetFactorsSetupForUserFromOtherRecipesFunc(f2) + ) + + PostSTInitCallbacks.add_post_init_callback(callback) + + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: + return False + + def get_apis_handled(self) -> List[APIHandled]: + return [ + APIHandled( + NormalisedURLPath(CREATE_TOTP_DEVICE), + "post", + CREATE_TOTP_DEVICE, + self.api_implementation.disable_create_device_post, + ), + APIHandled( + NormalisedURLPath(LIST_TOTP_DEVICES), + "get", + LIST_TOTP_DEVICES, + self.api_implementation.disable_list_devices_get, + ), + APIHandled( + NormalisedURLPath(REMOVE_TOTP_DEVICE), + "post", + REMOVE_TOTP_DEVICE, + self.api_implementation.disable_remove_device_post, + ), + APIHandled( + NormalisedURLPath(VERIFY_TOTP_DEVICE), + "post", + VERIFY_TOTP_DEVICE, + self.api_implementation.disable_verify_device_post, + ), + APIHandled( + NormalisedURLPath(VERIFY_TOTP), + "post", + VERIFY_TOTP, + self.api_implementation.disable_verify_totp_post, + ), + ] + + async def handle_api_request( + self, + request_id: str, + tenant_id: str, + request: BaseRequest, + path: NormalisedURLPath, + method: str, + response: BaseResponse, + user_context: Dict[str, Any], + ): + api_options = APIOptions( + request, + response, + self.recipe_id, + self.config, + self.recipe_implementation, + self.get_app_info(), + self, + ) + if request_id == CREATE_TOTP_DEVICE: + return await handle_create_device_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == LIST_TOTP_DEVICES: + return await handle_list_devices_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == REMOVE_TOTP_DEVICE: + return await handle_remove_device_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == VERIFY_TOTP_DEVICE: + return await handle_verify_device_api( + tenant_id, self.api_implementation, api_options, user_context + ) + if request_id == VERIFY_TOTP: + return await handle_verify_totp_api( + tenant_id, self.api_implementation, api_options, user_context + ) + + return None + + async def handle_error( + self, + request: BaseRequest, + err: SuperTokensError, + response: BaseResponse, + user_context: Dict[str, Any], + ) -> BaseResponse: + raise err + + def get_all_cors_headers(self) -> List[str]: + return [] + + @staticmethod + def init( + config: Union[TOTPConfig, None] = None, + ): + def func(app_info: AppInfo): + if TOTPRecipe.__instance is None: + TOTPRecipe.__instance = TOTPRecipe( + TOTPRecipe.recipe_id, + app_info, + config, + ) + return TOTPRecipe.__instance + raise Exception( + "TOTP recipe has already been initialised. Please check your code for bugs." + ) + + return func + + @staticmethod + def get_instance_or_throw() -> TOTPRecipe: + if TOTPRecipe.__instance is not None: + return TOTPRecipe.__instance + raise_general_exception( + "Initialisation not done. Did you forget to call the SuperTokens.init function?" + ) + + @staticmethod + def reset(): + if ("SUPERTOKENS_ENV" not in environ) or ( + environ["SUPERTOKENS_ENV"] != "testing" + ): + raise_general_exception("calling testing function in non testing env") + TOTPRecipe.__instance = None diff --git a/supertokens_python/recipe/totp/recipe_implementation.py b/supertokens_python/recipe/totp/recipe_implementation.py new file mode 100644 index 000000000..ce213dd88 --- /dev/null +++ b/supertokens_python/recipe/totp/recipe_implementation.py @@ -0,0 +1,267 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from urllib.parse import quote + +from supertokens_python.normalised_url_path import NormalisedURLPath +from supertokens_python.recipe.totp.interfaces import ( + RecipeInterface, +) +from .types import ( + UnknownUserIdError, + UpdateDeviceOkResult, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + VerifyTOTPOkResult, + UserIdentifierInfoOkResult, + UserIdentifierInfoDoesNotExistError, + CreateDeviceOkResult, + Device, + DeviceAlreadyExistsError, + InvalidTOTPError, + LimitReachedError, + TOTPNormalisedConfig, + UnknownDeviceError, +) +from supertokens_python.asyncio import get_user + +if TYPE_CHECKING: + from supertokens_python.querier import Querier + + +class RecipeImplementation(RecipeInterface): + def __init__( + self, + querier: Querier, + config: TOTPNormalisedConfig, + ): + super().__init__() + self.querier = querier + self.config = config + + async def get_user_identifier_info_for_user_id( + self, user_id: str, user_context: Dict[str, Any] + ) -> Union[ + UserIdentifierInfoOkResult, + UnknownUserIdError, + UserIdentifierInfoDoesNotExistError, + ]: + user = await get_user(user_id, user_context) + + if user is None: + return UnknownUserIdError() + + primary_login_method = next( + ( + method + for method in user.login_methods + if method.recipe_user_id.get_as_string() == user.id + ), + None, + ) + + if primary_login_method is not None: + if primary_login_method.email is not None: + return UserIdentifierInfoOkResult(primary_login_method.email) + elif primary_login_method.phone_number is not None: + return UserIdentifierInfoOkResult(primary_login_method.phone_number) + + if user.emails: + return UserIdentifierInfoOkResult(user.emails[0]) + elif user.phone_numbers: + return UserIdentifierInfoOkResult(user.phone_numbers[0]) + + return UserIdentifierInfoDoesNotExistError() + + async def create_device( + self, + user_id: str, + user_identifier_info: Optional[str], + device_name: Optional[str], + skew: Optional[int], + period: Optional[int], + user_context: Dict[str, Any], + ) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError,]: + if user_identifier_info is None: + email_or_phone_info = await self.get_user_identifier_info_for_user_id( + user_id, user_context + ) + if isinstance(email_or_phone_info, UserIdentifierInfoOkResult): + user_identifier_info = email_or_phone_info.info + elif isinstance(email_or_phone_info, UnknownUserIdError): + return UnknownUserIdError() + + data = { + "userId": user_id, + "skew": skew if skew is not None else self.config.default_skew, + "period": period if period is not None else self.config.default_period, + } + if device_name is not None: + data["deviceName"] = device_name + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/totp/device"), + data, + user_context=user_context, + ) + + qr_code_string = ( + f"otpauth://totp/{quote(self.config.issuer)}" + f"{':' + quote(user_identifier_info) if user_identifier_info is not None else ''}" + f"?secret={response['secret']}&issuer={quote(self.config.issuer)}&digits=6" + f"&period={period if period is not None else self.config.default_period}" + ) + + return CreateDeviceOkResult( + device_name=response["deviceName"], + secret=response["secret"], + qr_code_string=qr_code_string, + ) + + async def update_device( + self, + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Dict[str, Any], + ) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError,]: + # Prepare the data for the API request + data = { + "userId": user_id, + "existingDeviceName": existing_device_name, + "newDeviceName": new_device_name, + } + + # Send a PUT request to update the device + resp = await self.querier.send_put_request( + NormalisedURLPath("/recipe/totp/device"), + data, + user_context=user_context, + ) + + # Handle the response based on the status + if resp["status"] == "OK": + return UpdateDeviceOkResult() + elif resp["status"] == "UNKNOWN_DEVICE_ERROR": + return UnknownDeviceError() + elif resp["status"] == "DEVICE_ALREADY_EXISTS_ERROR": + return DeviceAlreadyExistsError() + else: + # Raise an exception for unknown errors + raise Exception("Unknown error") + + async def list_devices( + self, user_id: str, user_context: Dict[str, Any] + ) -> ListDevicesOkResult: + params = {"userId": user_id} + response = await self.querier.send_get_request( + NormalisedURLPath("/recipe/totp/device/list"), + params, + user_context=user_context, + ) + return ListDevicesOkResult( + devices=[ + Device( + name=device["name"], + period=device["period"], + skew=device["skew"], + verified=device["verified"], + ) + for device in response["devices"] + ] + ) + + async def remove_device( + self, user_id: str, device_name: str, user_context: Dict[str, Any] + ) -> RemoveDeviceOkResult: + data = {"userId": user_id, "deviceName": device_name} + response = await self.querier.send_post_request( + NormalisedURLPath("/recipe/totp/device/remove"), + data, + user_context=user_context, + ) + return RemoveDeviceOkResult(did_device_exist=response["didDeviceExist"]) + + async def verify_device( + self, + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Dict[str, Any], + ) -> Union[ + VerifyDeviceOkResult, + UnknownDeviceError, + InvalidTOTPError, + LimitReachedError, + ]: + data = {"userId": user_id, "deviceName": device_name, "totp": totp} + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/totp/device/verify"), + data, + user_context=user_context, + ) + if response["status"] == "OK": + return VerifyDeviceOkResult( + was_already_verified=response["wasAlreadyVerified"] + ) + elif response["status"] == "UNKNOWN_DEVICE_ERROR": + return UnknownDeviceError() + elif response["status"] == "INVALID_TOTP_ERROR": + return InvalidTOTPError( + current_number_of_failed_attempts=response[ + "currentNumberOfFailedAttempts" + ], + max_number_of_failed_attempts=response["maxNumberOfFailedAttempts"], + ) + elif response["status"] == "LIMIT_REACHED_ERROR": + return LimitReachedError( + retry_after_ms=response["retryAfterMs"], + ) + else: + raise Exception("Unknown error") + + async def verify_totp( + self, tenant_id: str, user_id: str, totp: str, user_context: Dict[str, Any] + ) -> Union[ + VerifyTOTPOkResult, + UnknownUserIdError, + InvalidTOTPError, + LimitReachedError, + ]: + data = {"userId": user_id, "totp": totp} + response = await self.querier.send_post_request( + NormalisedURLPath(f"{tenant_id}/recipe/totp/verify"), + data, + user_context=user_context, + ) + if response["status"] == "OK": + return VerifyTOTPOkResult() + elif response["status"] == "UNKNOWN_USER_ID_ERROR": + return UnknownUserIdError() + elif response["status"] == "INVALID_TOTP_ERROR": + return InvalidTOTPError( + current_number_of_failed_attempts=response[ + "currentNumberOfFailedAttempts" + ], + max_number_of_failed_attempts=response["maxNumberOfFailedAttempts"], + ) + elif response["status"] == "LIMIT_REACHED_ERROR": + return LimitReachedError( + retry_after_ms=response["retryAfterMs"], + ) + else: + raise Exception("Unknown error") diff --git a/supertokens_python/recipe/totp/syncio/__init__.py b/supertokens_python/recipe/totp/syncio/__init__.py new file mode 100644 index 000000000..24eb4d2ed --- /dev/null +++ b/supertokens_python/recipe/totp/syncio/__init__.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, Dict, Union, Optional + +from supertokens_python.async_to_sync_wrapper import sync + +from supertokens_python.recipe.totp.types import ( + CreateDeviceOkResult, + DeviceAlreadyExistsError, + UnknownUserIdError, + UpdateDeviceOkResult, + UnknownDeviceError, + ListDevicesOkResult, + RemoveDeviceOkResult, + VerifyDeviceOkResult, + InvalidTOTPError, + LimitReachedError, + VerifyTOTPOkResult, +) + + +def create_device( + user_id: str, + user_identifier_info: Optional[str] = None, + device_name: Optional[str] = None, + skew: Optional[int] = None, + period: Optional[int] = None, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[CreateDeviceOkResult, DeviceAlreadyExistsError, UnknownUserIdError]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import create_device as async_func + + return sync( + async_func( + user_id, user_identifier_info, device_name, skew, period, user_context + ) + ) + + +def update_device( + user_id: str, + existing_device_name: str, + new_device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[UpdateDeviceOkResult, UnknownDeviceError, DeviceAlreadyExistsError]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import update_device as async_func + + return sync( + async_func(user_id, existing_device_name, new_device_name, user_context) + ) + + +def list_devices( + user_id: str, + user_context: Optional[Dict[str, Any]] = None, +) -> ListDevicesOkResult: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import list_devices as async_func + + return sync(async_func(user_id, user_context)) + + +def remove_device( + user_id: str, + device_name: str, + user_context: Optional[Dict[str, Any]] = None, +) -> RemoveDeviceOkResult: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import remove_device as async_func + + return sync(async_func(user_id, device_name, user_context)) + + +def verify_device( + tenant_id: str, + user_id: str, + device_name: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[ + VerifyDeviceOkResult, UnknownDeviceError, InvalidTOTPError, LimitReachedError +]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import verify_device as async_func + + return sync(async_func(tenant_id, user_id, device_name, totp, user_context)) + + +def verify_totp( + tenant_id: str, + user_id: str, + totp: str, + user_context: Optional[Dict[str, Any]] = None, +) -> Union[VerifyTOTPOkResult, UnknownUserIdError, InvalidTOTPError, LimitReachedError]: + if user_context is None: + user_context = {} + + from supertokens_python.recipe.totp.asyncio import verify_totp as async_func + + return sync(async_func(tenant_id, user_id, totp, user_context)) diff --git a/supertokens_python/recipe/totp/types.py b/supertokens_python/recipe/totp/types.py new file mode 100644 index 000000000..f9599bf65 --- /dev/null +++ b/supertokens_python/recipe/totp/types.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Any, Callable, Optional +from typing_extensions import Literal +from .interfaces import RecipeInterface, APIInterface + +from supertokens_python.types import APIResponse + + +class OkResult(APIResponse): + def __init__(self): + self.status: Literal["OK"] = "OK" + + +class UserIdentifierInfoOkResult(OkResult): + def __init__(self, info: str): + super().__init__() + self.info: str = info + + def to_json(self) -> Dict[str, Any]: + raise NotImplementedError() + + +class UnknownUserIdError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_USER_ID_ERROR"] = "UNKNOWN_USER_ID_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class UserIdentifierInfoDoesNotExistError: + def __init__(self): + self.status: Literal[ + "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" + ] = "USER_IDENTIFIER_INFO_DOES_NOT_EXIST_ERROR" + + +class CreateDeviceOkResult(OkResult): + def __init__(self, device_name: str, secret: str, qr_code_string: str): + super().__init__() + self.device_name: str = device_name + self.secret: str = secret + self.qr_code_string: str = qr_code_string + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "deviceName": self.device_name, + "secret": self.secret, + "qrCodeString": self.qr_code_string, + } + + +class DeviceAlreadyExistsError(APIResponse): + def __init__(self): + self.status: Literal[ + "DEVICE_ALREADY_EXISTS_ERROR" + ] = "DEVICE_ALREADY_EXISTS_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class UpdateDeviceOkResult(OkResult): + def to_json(self) -> Dict[str, Any]: + raise NotImplementedError() + + +class UnknownDeviceError(APIResponse): + def __init__(self): + self.status: Literal["UNKNOWN_DEVICE_ERROR"] = "UNKNOWN_DEVICE_ERROR" + + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class Device(APIResponse): + def __init__(self, name: str, period: int, skew: int, verified: bool): + self.name: str = name + self.period: int = period + self.skew: int = skew + self.verified: bool = verified + + def to_json(self) -> Dict[str, Any]: + return { + "name": self.name, + "period": self.period, + "skew": self.skew, + "verified": self.verified, + } + + +class ListDevicesOkResult(OkResult): + def __init__(self, devices: List[Device]): + super().__init__() + self.devices: List[Device] = devices + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "devices": [device.to_json() for device in self.devices], + } + + +class RemoveDeviceOkResult(OkResult): + def __init__(self, did_device_exist: bool): + super().__init__() + self.did_device_exist: bool = did_device_exist + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "didDeviceExist": self.did_device_exist, + } + + +class VerifyDeviceOkResult(OkResult): + def __init__( + self, + was_already_verified: bool, + ): + super().__init__() + self.was_already_verified: bool = was_already_verified + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "wasAlreadyVerified": self.was_already_verified, + } + + +class InvalidTOTPError(APIResponse): + def __init__( + self, current_number_of_failed_attempts: int, max_number_of_failed_attempts: int + ): + self.status: Literal["INVALID_TOTP_ERROR"] = "INVALID_TOTP_ERROR" + self.current_number_of_failed_attempts: int = current_number_of_failed_attempts + self.max_number_of_failed_attempts: int = max_number_of_failed_attempts + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "currentNumberOfFailedAttempts": self.current_number_of_failed_attempts, + "maxNumberOfFailedAttempts": self.max_number_of_failed_attempts, + } + + +class LimitReachedError(APIResponse): + def __init__(self, retry_after_ms: int): + self.status: Literal["LIMIT_REACHED_ERROR"] = "LIMIT_REACHED_ERROR" + self.retry_after_ms: int = retry_after_ms + + def to_json(self) -> Dict[str, Any]: + return { + "status": self.status, + "retryAfterMs": self.retry_after_ms, + } + + +class VerifyTOTPOkResult(OkResult): + def to_json(self) -> Dict[str, Any]: + return {"status": self.status} + + +class OverrideConfig: + def __init__( + self, + functions: Optional[Callable[[RecipeInterface], RecipeInterface]] = None, + apis: Optional[Callable[[APIInterface], APIInterface]] = None, + ): + self.functions = functions + self.apis = apis + + +class TOTPConfig: + def __init__( + self, + issuer: Optional[str] = None, + default_skew: Optional[int] = None, + default_period: Optional[int] = None, + override: Optional[OverrideConfig] = None, + ): + self.issuer = issuer + self.default_skew = default_skew + self.default_period = default_period + self.override = override + + +class TOTPNormalisedConfig: + def __init__( + self, + issuer: str, + default_skew: int, + default_period: int, + override: OverrideConfig, + ): + self.issuer = issuer + self.default_skew = default_skew + self.default_period = default_period + self.override = override diff --git a/supertokens_python/recipe/totp/utils.py b/supertokens_python/recipe/totp/utils.py new file mode 100644 index 000000000..4684fa7f8 --- /dev/null +++ b/supertokens_python/recipe/totp/utils.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from typing import Union + +from supertokens_python import AppInfo +from .types import TOTPConfig, TOTPNormalisedConfig, OverrideConfig + + +def validate_and_normalise_user_input( + app_info: AppInfo, config: Union[TOTPConfig, None] +) -> TOTPNormalisedConfig: + if config is None: + config = TOTPConfig() + + issuer = config.issuer if config.issuer is not None else app_info.app_name + default_skew = config.default_skew if config.default_skew is not None else 1 + default_period = config.default_period if config.default_period is not None else 30 + + if config.override is None: + override = OverrideConfig() + else: + override = OverrideConfig( + functions=config.override.functions, + apis=config.override.apis, + ) + + return TOTPNormalisedConfig( + issuer=issuer, + default_skew=default_skew, + default_period=default_period, + override=override, + ) diff --git a/supertokens_python/recipe/usermetadata/recipe.py b/supertokens_python/recipe/usermetadata/recipe.py index 763390c80..6d45abfe2 100644 --- a/supertokens_python/recipe/usermetadata/recipe.py +++ b/supertokens_python/recipe/usermetadata/recipe.py @@ -15,7 +15,7 @@ from __future__ import annotations from os import environ -from typing import List, Union, Optional, Dict, Any +from typing import List, Union, Optional, Dict, Any, TYPE_CHECKING from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse @@ -31,7 +31,9 @@ validate_and_normalise_user_input, ) from supertokens_python.recipe_module import APIHandled, RecipeModule -from supertokens_python.supertokens import AppInfo + +if TYPE_CHECKING: + from supertokens_python.supertokens import AppInfo from .utils import InputOverrideConfig diff --git a/supertokens_python/recipe/usermetadata/utils.py b/supertokens_python/recipe/usermetadata/utils.py index 66174196c..7e059cf74 100644 --- a/supertokens_python/recipe/usermetadata/utils.py +++ b/supertokens_python/recipe/usermetadata/utils.py @@ -20,10 +20,10 @@ APIInterface, RecipeInterface, ) -from supertokens_python.supertokens import AppInfo if TYPE_CHECKING: from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe + from supertokens_python.supertokens import AppInfo class InputOverrideConfig: diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index a4f4189b4..844b41895 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -27,6 +27,7 @@ from supertokens_python.recipe.userroles.utils import validate_and_normalise_user_input from supertokens_python.recipe_module import APIHandled, RecipeModule from supertokens_python.supertokens import AppInfo +from supertokens_python.types import RecipeUserId from ...post_init_callbacks import PostSTInitCallbacks from ..session import SessionRecipe @@ -151,7 +152,11 @@ def __init__(self) -> None: default_max_age_in_sec = 300 async def fetch_value( - user_id: str, tenant_id: str, user_context: Dict[str, Any] + user_id: str, + _recipe_user_id: RecipeUserId, + tenant_id: str, + _current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> List[str]: recipe = UserRolesRecipe.get_instance() @@ -186,7 +191,11 @@ def __init__(self) -> None: default_max_age_in_sec = 300 async def fetch_value( - user_id: str, tenant_id: str, user_context: Dict[str, Any] + user_id: str, + _recipe_user_id: RecipeUserId, + tenant_id: str, + _current_payload: Dict[str, Any], + user_context: Dict[str, Any], ) -> List[str]: recipe = UserRolesRecipe.get_instance() res = await recipe.recipe_implementation.get_roles_for_user( diff --git a/supertokens_python/supertokens.py b/supertokens_python/supertokens.py index b47068823..6aae8aa9b 100644 --- a/supertokens_python/supertokens.py +++ b/supertokens_python/supertokens.py @@ -26,7 +26,7 @@ ) -from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT, USER_DELETE, USERS +from .constants import FDI_KEY_HEADER, RID_KEY_HEADER, USER_COUNT from .exceptions import SuperTokensError from .interfaces import ( CreateUserIdMappingOkResult, @@ -42,7 +42,6 @@ from .normalised_url_path import NormalisedURLPath from .post_init_callbacks import PostSTInitCallbacks from .querier import Querier -from .types import ThirdPartyInfo, User, UsersResponse from .utils import ( get_rid_from_header, get_top_level_domain_for_same_site_resolution, @@ -256,22 +255,36 @@ def __init__( "Please provide at least one recipe to the supertokens.init function call" ) - from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe - multitenancy_found = False + totp_found = False + user_metadata_found = False + multi_factor_auth_found = False def make_recipe(recipe: Callable[[AppInfo], RecipeModule]) -> RecipeModule: - nonlocal multitenancy_found + nonlocal multitenancy_found, totp_found, user_metadata_found, multi_factor_auth_found recipe_module = recipe(self.app_info) - if recipe_module.get_recipe_id() == MultitenancyRecipe.recipe_id: + if recipe_module.get_recipe_id() == "multitenancy": multitenancy_found = True + elif recipe_module.get_recipe_id() == "usermetadata": + user_metadata_found = True + elif recipe_module.get_recipe_id() == "multifactorauth": + multi_factor_auth_found = True + elif recipe_module.get_recipe_id() == "totp": + totp_found = True return recipe_module self.recipe_modules: List[RecipeModule] = list(map(make_recipe, recipe_list)) if not multitenancy_found: - recipe = MultitenancyRecipe.init()(self.app_info) - self.recipe_modules.append(recipe) + from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe + + self.recipe_modules.append(MultitenancyRecipe.init()(self.app_info)) + if totp_found and not multi_factor_auth_found: + raise Exception("Please initialize the MultiFactorAuth recipe to use TOTP.") + if not user_metadata_found: + from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe + + self.recipe_modules.append(UserMetadataRecipe.init()(self.app_info)) self.telemetry = ( telemetry @@ -307,6 +320,9 @@ def reset(): environ["SUPERTOKENS_ENV"] != "testing" ): raise_general_exception("calling testing function in non testing env") + from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe + + UserMetadataRecipe.reset() Querier.reset() Supertokens.__instance = None @@ -351,87 +367,6 @@ async def get_user_count( # pylint: disable=no-self-use return int(response["count"]) - async def delete_user( # pylint: disable=no-self-use - self, - user_id: str, - user_context: Optional[Dict[str, Any]], - ) -> None: - querier = Querier.get_instance(None) - - cdi_version = await querier.get_api_version(user_context) - - if is_version_gte(cdi_version, "2.10"): - await querier.send_post_request( - NormalisedURLPath(USER_DELETE), - {"userId": user_id}, - user_context=user_context, - ) - - return None - raise_general_exception("Please upgrade the SuperTokens core to >= 3.7.0") - - async def get_users( # pylint: disable=no-self-use - self, - tenant_id: str, - time_joined_order: Literal["ASC", "DESC"], - limit: Union[int, None], - pagination_token: Union[str, None], - include_recipe_ids: Union[None, List[str]], - query: Union[Dict[str, str], None], - user_context: Optional[Dict[str, Any]], - ) -> UsersResponse: - - querier = Querier.get_instance(None) - params = {"timeJoinedOrder": time_joined_order} - if limit is not None: - params = {"limit": limit, **params} - if pagination_token is not None: - params = {"paginationToken": pagination_token, **params} - - include_recipe_ids_str = None - if include_recipe_ids is not None: - include_recipe_ids_str = ",".join(include_recipe_ids) - params = {"includeRecipeIds": include_recipe_ids_str, **params} - - if query is not None: - params = {**params, **query} - - response = await querier.send_get_request( - NormalisedURLPath(f"/{tenant_id}{USERS}"), params, user_context=user_context - ) - next_pagination_token = None - if "nextPaginationToken" in response: - next_pagination_token = response["nextPaginationToken"] - users_list = response["users"] - users: List[User] = [] - for user in users_list: - recipe_id = user["recipeId"] - user_obj = user["user"] - third_party = None - if "thirdParty" in user_obj: - third_party = ThirdPartyInfo( - user_obj["thirdParty"]["userId"], user_obj["thirdParty"]["id"] - ) - email = None - if "email" in user_obj: - email = user_obj["email"] - phone_number = None - if "phoneNumber" in user_obj: - phone_number = user_obj["phoneNumber"] - users.append( - User( - recipe_id, - user_obj["id"], - user_obj["timeJoined"], - email, - phone_number, - third_party, - user_obj["tenantIds"], - ) - ) - - return UsersResponse(users, next_pagination_token) - async def create_user_id_mapping( # pylint: disable=no-self-use self, supertokens_user_id: str, diff --git a/supertokens_python/syncio/__init__.py b/supertokens_python/syncio/__init__.py index b1557b074..e7b50292f 100644 --- a/supertokens_python/syncio/__init__.py +++ b/supertokens_python/syncio/__init__.py @@ -25,7 +25,7 @@ UserIdMappingAlreadyExistsError, UserIDTypes, ) -from supertokens_python.types import UsersResponse +from supertokens_python.types import AccountInfo, User def get_users_oldest_first( @@ -35,11 +35,12 @@ def get_users_oldest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: +): + from supertokens_python.asyncio import get_users_oldest_first + return sync( - Supertokens.get_instance().get_users( + get_users_oldest_first( tenant_id, - "ASC", limit, pagination_token, include_recipe_ids, @@ -56,11 +57,12 @@ def get_users_newest_first( include_recipe_ids: Union[None, List[str]] = None, query: Union[None, Dict[str, str]] = None, user_context: Optional[Dict[str, Any]] = None, -) -> UsersResponse: +): + from supertokens_python.asyncio import get_users_newest_first + return sync( - Supertokens.get_instance().get_users( + get_users_newest_first( tenant_id, - "DESC", limit, pagination_token, include_recipe_ids, @@ -82,8 +84,22 @@ def get_user_count( ) -def delete_user(user_id: str, user_context: Optional[Dict[str, Any]] = None) -> None: - return sync(Supertokens.get_instance().delete_user(user_id, user_context)) +def delete_user( + user_id: str, + remove_all_linked_accounts: bool = True, + user_context: Optional[Dict[str, Any]] = None, +) -> None: + from supertokens_python.asyncio import delete_user + + return sync(delete_user(user_id, remove_all_linked_accounts, user_context)) + + +def get_user( + user_id: str, user_context: Optional[Dict[str, Any]] = None +) -> Optional[User]: + from supertokens_python.asyncio import get_user as async_get_user + + return sync(async_get_user(user_id, user_context)) def create_user_id_mapping( @@ -144,3 +160,20 @@ def update_or_delete_user_id_mapping_info( user_id, user_id_type, external_user_id_info, user_context ) ) + + +def list_users_by_account_info( + tenant_id: str, + account_info: AccountInfo, + do_union_of_account_info: bool = False, + user_context: Optional[Dict[str, Any]] = None, +) -> List[User]: + from supertokens_python.asyncio import ( + list_users_by_account_info as async_list_users_by_account_info, + ) + + return sync( + async_list_users_by_account_info( + tenant_id, account_info, do_union_of_account_info, user_context + ) + ) diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 1c7b2f799..4e9b8ecd6 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -11,61 +11,224 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Awaitable, Dict, List, TypeVar, Union +from typing import Any, Awaitable, Dict, List, TypeVar, Union, Optional, TYPE_CHECKING +from phonenumbers import format_number, parse # type: ignore +import phonenumbers # type: ignore +from typing_extensions import Literal _T = TypeVar("_T") +if TYPE_CHECKING: + from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo -class ThirdPartyInfo: - def __init__(self, third_party_user_id: str, third_party_id: str): - self.user_id = third_party_user_id - self.id = third_party_id +class RecipeUserId: + def __init__(self, recipe_user_id: str): + self.recipe_user_id = recipe_user_id -class User: + def get_as_string(self) -> str: + return self.recipe_user_id + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RecipeUserId): + return self.recipe_user_id == other.recipe_user_id + return False + + +class AccountInfo: def __init__( self, - recipe_id: str, - user_id: str, - time_joined: int, + email: Optional[str] = None, + phone_number: Optional[str] = None, + third_party: Optional[ThirdPartyInfo] = None, + ): + self.email = email + self.phone_number = phone_number + self.third_party = third_party + + def to_json(self) -> Dict[str, Any]: + json_repo: Dict[str, Any] = {} + if self.email is not None: + json_repo["email"] = self.email + if self.phone_number is not None: + json_repo["phoneNumber"] = self.phone_number + if self.third_party is not None: + json_repo["thirdParty"] = { + "id": self.third_party.id, + "userId": self.third_party.user_id, + } + return json_repo + + +class LoginMethod(AccountInfo): + def __init__( + self, + recipe_id: Literal["emailpassword", "thirdparty", "passwordless"], + recipe_user_id: str, + tenant_ids: List[str], email: Union[str, None], phone_number: Union[str, None], - third_party_info: Union[ThirdPartyInfo, None], - tenant_ids: List[str], + third_party: Union[ThirdPartyInfo, None], + time_joined: int, + verified: bool, ): - self.recipe_id = recipe_id - self.user_id = user_id - self.email = email + super().__init__(email, phone_number, third_party) + self.recipe_id: Literal[ + "emailpassword", "thirdparty", "passwordless" + ] = recipe_id + self.recipe_user_id = RecipeUserId(recipe_user_id) + self.tenant_ids: List[str] = tenant_ids self.time_joined = time_joined - self.third_party_info = third_party_info - self.phone_number = phone_number - self.tenant_ids = tenant_ids + self.verified = verified + + def __eq__(self, other: Any) -> bool: + if isinstance(other, LoginMethod): + return ( + self.recipe_id == other.recipe_id + and self.recipe_user_id == other.recipe_user_id + and self.tenant_ids == other.tenant_ids + and self.email == other.email + and self.phone_number == other.phone_number + and self.third_party == other.third_party + and self.time_joined == other.time_joined + and self.verified == other.verified + ) + return False + + def has_same_email_as(self, email: Union[str, None]) -> bool: + if email is None: + return False + return ( + self.email is not None + and self.email.lower().strip() == email.lower().strip() + ) + + def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: + if phone_number is None: + return False + + cleaned_phone = phone_number.strip() + try: + cleaned_phone = format_number( + parse(phone_number, None), phonenumbers.PhoneNumberFormat.E164 + ) + except Exception: + pass # here we just use the stripped version + + return self.phone_number is not None and self.phone_number == cleaned_phone + + def has_same_third_party_info_as( + self, third_party: Union[ThirdPartyInfo, None] + ) -> bool: + if third_party is None: + return False + return ( + self.third_party is not None + and self.third_party.id.strip() == third_party.id.strip() + and self.third_party.user_id.strip() == third_party.user_id.strip() + ) def to_json(self) -> Dict[str, Any]: - res: Dict[str, Any] = { + return { "recipeId": self.recipe_id, - "user": { - "id": self.user_id, - "timeJoined": self.time_joined, - "tenantIds": self.tenant_ids, - }, + "recipeUserId": self.recipe_user_id.get_as_string(), + "tenantIds": self.tenant_ids, + "email": self.email, + "phoneNumber": self.phone_number, + "thirdParty": self.third_party.to_json() if self.third_party else None, + "timeJoined": self.time_joined, + "verified": self.verified, } - if self.email is not None: - res["user"]["email"] = self.email - if self.phone_number is not None: - res["user"]["phoneNumber"] = self.phone_number - if self.third_party_info is not None: - res["user"]["thirdParty"] = self.third_party_info.__dict__ + @staticmethod + def from_json(json: Dict[str, Any]) -> "LoginMethod": + from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo as TPI + + return LoginMethod( + recipe_id=json["recipeId"], + recipe_user_id=json["recipeUserId"], + tenant_ids=json["tenantIds"], + email=( + json["email"] if "email" in json and json["email"] is not None else None + ), + phone_number=( + json["phoneNumber"] + if "phoneNumber" in json and json["phoneNumber"] is not None + else None + ), + third_party=( + ( + TPI(json["thirdParty"]["userId"], json["thirdParty"]["id"]) + if "thirdParty" in json and json["thirdParty"] is not None + else None + ) + ), + time_joined=json["timeJoined"], + verified=json["verified"], + ) - return res +class User: + def __init__( + self, + user_id: str, + is_primary_user: bool, + tenant_ids: List[str], + emails: List[str], + phone_numbers: List[str], + third_party: List[ThirdPartyInfo], + login_methods: List[LoginMethod], + time_joined: int, + ): + self.id = user_id + self.is_primary_user = is_primary_user + self.tenant_ids = tenant_ids + self.emails = emails + self.phone_numbers = phone_numbers + self.third_party = third_party + self.login_methods = login_methods + self.time_joined = time_joined + + def __eq__(self, other: Any) -> bool: + if isinstance(other, User): + return ( + self.id == other.id + and self.is_primary_user == other.is_primary_user + and self.tenant_ids == other.tenant_ids + and self.emails == other.emails + and self.phone_numbers == other.phone_numbers + and self.third_party == other.third_party + and self.login_methods == other.login_methods + and self.time_joined == other.time_joined + ) + return False + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "isPrimaryUser": self.is_primary_user, + "tenantIds": self.tenant_ids, + "emails": self.emails, + "phoneNumbers": self.phone_numbers, + "thirdParty": self.third_party, + "loginMethods": [lm.to_json() for lm in self.login_methods], + "timeJoined": self.time_joined, + } -class UsersResponse: - def __init__(self, users: List[User], next_pagination_token: Union[str, None]): - self.users: List[User] = users - self.next_pagination_token: Union[str, None] = next_pagination_token + @staticmethod + def from_json(json: Dict[str, Any]) -> "User": + return User( + user_id=json["id"], + is_primary_user=json["isPrimaryUser"], + tenant_ids=json["tenantIds"], + emails=json["emails"], + phone_numbers=json["phoneNumbers"], + third_party=json["thirdParty"], + login_methods=[LoginMethod.from_json(lm) for lm in json["loginMethods"]], + time_joined=json["timeJoined"], + ) class APIResponse(ABC): diff --git a/supertokens_python/utils.py b/supertokens_python/utils.py index aa202a0ec..b199e55a6 100644 --- a/supertokens_python/utils.py +++ b/supertokens_python/utils.py @@ -44,9 +44,13 @@ from supertokens_python.framework.response import BaseResponse from supertokens_python.logger import log_debug_message -from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER +if TYPE_CHECKING: + from supertokens_python.recipe.session import SessionContainer + +from .constants import ERROR_MESSAGE_KEY, RID_KEY_HEADER, FDI_KEY_HEADER from .exceptions import raise_general_exception from .types import MaybeAwaitable +from supertokens_python.types import User _T = TypeVar("_T") @@ -179,6 +183,13 @@ def utf_base64decode(s: str, urlsafe: bool) -> str: return b64decode(s.encode("utf-8")).decode("utf-8") +def encode_base64(value: str) -> str: + """ + Encode the passed value to base64 and return the encoded value. + """ + return b64encode(value.encode()).decode() + + def get_filtered_list(func: Callable[[_T], bool], given_list: List[_T]) -> List[_T]: return list(filter(func, given_list)) @@ -286,7 +297,81 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str: "Please make sure that the apiDomain and websiteDomain have correct values" ) - return parsed_url.domain + "." + parsed_url.suffix # type: ignore + return parsed_url.domain + "." + parsed_url.suffix + + +def get_backwards_compatible_user_info( + req: BaseRequest, + user_info: User, + session_container: SessionContainer, + created_new_recipe_user: Union[bool, None], + user_context: Dict[str, Any], +) -> Dict[str, Any]: + resp: Dict[str, Any] = {} + # (>= 1.18 && < 2.0) || >= 3.0: This is because before 1.18, and between 2 and 3, FDI does not + # support account linking. + if ( + has_greater_than_equal_to_fdi(req, "1.18") + and not has_greater_than_equal_to_fdi(req, "2.0") + ) or has_greater_than_equal_to_fdi(req, "3.0"): + resp = {"user": user_info.to_json()} + + if created_new_recipe_user is not None: + resp["createdNewRecipeUser"] = created_new_recipe_user + return resp + + login_method = next( + ( + lm + for lm in user_info.login_methods + if lm.recipe_user_id.get_as_string() + == session_container.get_recipe_user_id(user_context).get_as_string() + ), + None, + ) + + if login_method is None: + # we pick the oldest login method here for the user. + # this can happen in case the user is implementing something like + # MFA where the session remains the same during the second factor as well. + login_method = min(user_info.login_methods, key=lambda lm: lm.time_joined) + + user_obj: Dict[str, Any] = { + "id": user_info.id, # we purposely use this instead of the loginmethod's recipeUserId because if the oldest login method is deleted, then this userID should remain the same. + "timeJoined": login_method.time_joined, + } + if login_method.third_party: + user_obj["thirdParty"] = login_method.third_party + if login_method.email: + user_obj["email"] = login_method.email + if login_method.phone_number: + user_obj["phoneNumber"] = login_method.phone_number + + resp = {"user": user_obj} + + if created_new_recipe_user is not None: + resp["createdNewUser"] = created_new_recipe_user + + return resp + + +def get_latest_fdi_version_from_fdi_list(fdi_header_value: str) -> str: + versions = fdi_header_value.split(",") + max_version_str = versions[0] + for version in versions[1:]: + max_version_str = _get_max_version(max_version_str, version) + return max_version_str + + +def has_greater_than_equal_to_fdi(req: BaseRequest, version: str) -> bool: + request_fdi = req.get_header(FDI_KEY_HEADER) + if request_fdi is None: + # By default we assume they want to use the latest FDI, this also helps with tests + return True + request_fdi = get_latest_fdi_version_from_fdi_list(request_fdi) + if request_fdi == version or _get_max_version(version, request_fdi) != version: + return True + return False class RWMutex: @@ -345,3 +430,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): def normalise_email(email: str) -> str: return email.strip().lower() + + +def get_normalised_should_try_linking_with_session_user_flag( + req: BaseRequest, body: Dict[str, Any] +) -> Optional[bool]: + if has_greater_than_equal_to_fdi(req, "3.1"): + return body.get("shouldTryLinkingWithSessionUser", False) + return None diff --git a/tests/Django/test_django.py b/tests/Django/test_django.py index 6305a1b2e..66cd27357 100644 --- a/tests/Django/test_django.py +++ b/tests/Django/test_django.py @@ -41,6 +41,7 @@ from supertokens_python.recipe.session.framework.django.asyncio import verify_session import pytest +from supertokens_python.types import RecipeUserId from tests.utils import ( clean_st, reset, @@ -86,7 +87,7 @@ def get_cookies(response: HttpResponse) -> Dict[str, Any]: async def create_new_session_view(request: HttpRequest): - await create_new_session(request, "public", "user_id") + await create_new_session(request, "public", RecipeUserId("user_id")) return JsonResponse({"foo": "bar"}) @@ -501,6 +502,7 @@ async def test_thirdparty_parsing_works(self): f"http://localhost:3000/redirect?state={state.replace('=', '%3D')}&code={code}", ) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_multiple_emails(self): init( @@ -553,6 +555,7 @@ async def test_search_with_multiple_emails(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 1) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_email_t(self): init( @@ -603,6 +606,7 @@ async def test_search_with_email_t(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 5) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_email_iresh(self): init( @@ -655,6 +659,7 @@ async def test_search_with_email_iresh(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 0) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_phone_plus_one(self): init( @@ -710,6 +715,7 @@ async def test_search_with_phone_plus_one(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 3) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_phone_one_bracket(self): init( @@ -765,6 +771,7 @@ async def test_search_with_phone_one_bracket(self): self.assertEqual(response.status_code, 200) self.assertEqual(len(data_json["users"]), 0) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_provider_google(self): init( @@ -859,6 +866,7 @@ async def test_search_with_provider_google(self): data_json = json.loads(response.content) self.assertEqual(len(data_json["users"]), 3) + @override_settings(ALLOWED_HOSTS=["testserver"]) @pytest.mark.asyncio async def test_search_with_provider_google_and_phone_one(self): init( @@ -979,7 +987,7 @@ async def test_that_verify_session_return_401_if_access_token_is_not_sent_and_mi # Create a session and get access token s = await create_new_session_without_request_response( - "public", "userId", {}, {} + "public", RecipeUserId("userId"), {}, {} ) access_token = s.get_access_token() headers = {"HTTP_AUTHORIZATION": "Bearer " + access_token} diff --git a/tests/Fastapi/test_fastapi.py b/tests/Fastapi/test_fastapi.py index 33f3b5013..6a76a3168 100644 --- a/tests/Fastapi/test_fastapi.py +++ b/tests/Fastapi/test_fastapi.py @@ -16,6 +16,7 @@ from fastapi import Depends, FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -91,7 +92,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") @@ -135,12 +136,12 @@ async def custom_logout(request: Request): # type: ignore @app.post("/create") async def _create(request: Request): # type: ignore - await create_new_session(request, "public", "userId", {}, {}) + await create_new_session(request, "public", RecipeUserId("userId"), {}, {}) return "" @app.post("/create-throw") async def _create_throw(request: Request): # type: ignore - await create_new_session(request, "public", "userId", {}, {}) + await create_new_session(request, "public", RecipeUserId("userId"), {}, {}) raise UnauthorisedError("unauthorised") return TestClient(app) diff --git a/tests/Flask/test_flask.py b/tests/Flask/test_flask.py index d87292fe2..ec45aadc3 100644 --- a/tests/Flask/test_flask.py +++ b/tests/Flask/test_flask.py @@ -32,6 +32,7 @@ refresh_session, revoke_session, ) +from supertokens_python.types import RecipeUserId from tests.Flask.utils import extract_all_cookies from tests.utils import ( TEST_ACCESS_TOKEN_MAX_AGE_CONFIG_KEY, @@ -193,7 +194,7 @@ def t(): # type: ignore @app.route("/login") # type: ignore def login(): # type: ignore user_id = "userId" - create_new_session(request, "public", user_id, {}, {}) + create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return jsonify({"userId": user_id, "session": "ssss"}) @@ -753,7 +754,7 @@ def test_api(): # type: ignore @app.route("/login") # type: ignore def login(): # type: ignore user_id = "userId" - s = create_new_session(request, "public", user_id, {}, {}) + s = create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return jsonify({"user": s.get_user_id()}) @app.route("/ping") # type: ignore @@ -832,7 +833,9 @@ def test_that_verify_session_return_401_if_access_token_is_not_sent_and_middlewa assert res.status_code == 401 assert res.json == {"message": "unauthorised"} - s = create_new_session_without_request_response("public", "userId", {}, {}) + s = create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) res = client.get( "/verify", headers={"Authorization": "Bearer " + s.get_access_token()} ) @@ -882,7 +885,7 @@ def _(_): @app.route("/create-session") # type: ignore def create_session_api(): # type: ignore - create_new_session(request, "public", "userId", {}, {}) + create_new_session(request, "public", RecipeUserId("userId"), {}, {}) return jsonify({}) return app diff --git a/tests/auth-react/django3x/manage.py b/tests/auth-react/django3x/manage.py index be146f802..ef5aa4985 100644 --- a/tests/auth-react/django3x/manage.py +++ b/tests/auth-react/django3x/manage.py @@ -5,6 +5,7 @@ def main(): + os.environ.setdefault("SUPERTOKENS_ENV", "testing") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings") try: from django.core.management import execute_from_command_line diff --git a/tests/auth-react/django3x/mysite/settings.py b/tests/auth-react/django3x/mysite/settings.py index a8f7d3686..6ad791b27 100644 --- a/tests/auth-react/django3x/mysite/settings.py +++ b/tests/auth-react/django3x/mysite/settings.py @@ -30,7 +30,7 @@ # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -custom_init(None, None) +custom_init() ALLOWED_HOSTS = ["localhost"] diff --git a/tests/auth-react/django3x/mysite/store.py b/tests/auth-react/django3x/mysite/store.py index 37f0dd2ea..3f2ece28c 100644 --- a/tests/auth-react/django3x/mysite/store.py +++ b/tests/auth-react/django3x/mysite/store.py @@ -1,18 +1,27 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Literal -_LATEST_URL_WITH_TOKEN = None +latest_url_with_token = "" def save_url_with_token(url_with_token: str): - global _LATEST_URL_WITH_TOKEN - _LATEST_URL_WITH_TOKEN = url_with_token # type: ignore + global latest_url_with_token + latest_url_with_token = url_with_token def get_url_with_token() -> str: - return _LATEST_URL_WITH_TOKEN # type: ignore + return latest_url_with_token -_CODE_STORE: Dict[str, List[Dict[str, Any]]] = {} +code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None def save_code( @@ -20,8 +29,8 @@ def save_code( url_with_link_code: Union[str, None], user_input_code: Union[str, None], ): - global _CODE_STORE - codes = _CODE_STORE.get(pre_auth_session_id, []) + global code_store + codes = code_store.get(pre_auth_session_id, []) # replace sub string in url_with_link_code if url_with_link_code: url_with_link_code = url_with_link_code.replace( @@ -30,8 +39,8 @@ def save_code( codes.append( {"urlWithLinkCode": url_with_link_code, "userInputCode": user_input_code} ) - _CODE_STORE[pre_auth_session_id] = codes + code_store[pre_auth_session_id] = codes def get_codes(pre_auth_session_id: str) -> List[Dict[str, Any]]: - return _CODE_STORE.get(pre_auth_session_id, []) + return code_store.get(pre_auth_session_id, []) diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index e5275d227..c3d03d502 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from dotenv import load_dotenv from typing_extensions import Literal @@ -9,11 +9,14 @@ from supertokens_python.recipe import ( emailpassword, emailverification, + multifactorauth, passwordless, session, thirdparty, + totp, userroles, ) +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.interfaces import ( @@ -51,6 +54,10 @@ from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) from supertokens_python.recipe.session.interfaces import ( APIInterface as SessionAPIInterface, ) @@ -61,12 +68,21 @@ ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import UserRolesRecipe -from supertokens_python.types import GeneralErrorResponse +from supertokens_python.types import GeneralErrorResponse, User from .store import save_code, save_url_with_token from supertokens_python.recipe import multitenancy +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe import accountlinking +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe load_dotenv() @@ -112,9 +128,7 @@ async def send_email( ) -class CustomPlessSMSService( - passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars] -): +class CustomSMSService(passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars]): async def send_sms( self, template_vars: passwordless.SMSTemplateVars, user_context: Dict[str, Any] ) -> None: @@ -260,12 +274,34 @@ async def get_user_info( # pylint: disable=no-self-use ] -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + import mysite.store + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() @@ -277,6 +313,8 @@ def custom_init( DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -392,6 +430,8 @@ async def password_reset_post( async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -405,12 +445,19 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -420,7 +467,12 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) original_implementation.email_exists_get = email_exists_get @@ -440,6 +492,8 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -453,6 +507,8 @@ async def sign_in_up_post( provider, redirect_uri_info, oauth_tokens, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -483,7 +539,7 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - session: Optional[SessionContainer], + session: SessionContainer, api_options: SAPIOptions, user_context: Dict[str, Any], ): @@ -507,6 +563,8 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -521,6 +579,8 @@ async def consume_code_post( user_input_code, device_id, link_code, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -529,6 +589,8 @@ async def consume_code_post( async def create_code_post( email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -539,12 +601,20 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, tenant_id, api_options, user_context + email, + phone_number, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) async def resend_code_post( device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -555,7 +625,13 @@ async def resend_code_post( if is_general_error: return GeneralErrorResponse("general error from API resend code") return await original_resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) original_implementation.consume_code_post = consume_code_post @@ -563,20 +639,94 @@ async def resend_code_post( original_implementation.resend_code_post = resend_code_post return original_implementation - if contact_method is not None and flow_type is not None: - if contact_method == "PHONE": + providers_list: List[thirdparty.ProviderInput] = [ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="google", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GOOGLE_CLIENT_ID"], + client_secret=os.environ["GOOGLE_CLIENT_SECRET"], + ), + ], + ), + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="github", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GITHUB_CLIENT_ID"], + client_secret=os.environ["GITHUB_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="facebook", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["FACEBOOK_CLIENT_ID"], + client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="auth0", + name="Auth0", + authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", + authorization_endpoint_query_params={"scope": "openid profile"}, + token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["AUTH0_CLIENT_ID"], + client_secret=os.environ["AUTH0_CLIENT_SECRET"], + ) + ], + ), + override=auth0_provider_override, + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), + ] + + if mysite.store.enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in mysite.store.enabled_providers + ] + + if mysite.store.contact_method is not None and mysite.store.flow_type is not None: + if mysite.store.contact_method == "PHONE": passwordless_init = passwordless.init( contact_config=ContactPhoneOnlyConfig(), - flow_type=flow_type, - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + flow_type=mysite.store.flow_type, + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), ) - elif contact_method == "EMAIL": + elif mysite.store.contact_method == "EMAIL": passwordless_init = passwordless.init( contact_config=ContactEmailOnlyConfig(), - flow_type=flow_type, + flow_type=mysite.store.flow_type, email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), @@ -587,11 +737,11 @@ async def resend_code_post( else: passwordless_init = passwordless.init( contact_config=ContactEmailOrPhoneConfig(), - flow_type=flow_type, + flow_type=mysite.store.flow_type, email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -601,7 +751,7 @@ async def resend_code_post( contact_config=ContactEmailOrPhoneConfig(), flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), ) @@ -610,34 +760,240 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mysite.store.mfa_info: + return mysite.store.mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mysite.store.mfa_info: + if factor_id not in mysite.store.mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mysite.store.mfa_info: + return mysite.store.mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mysite.store.mfa_info: + res.factors.already_setup = mysite.store.mfa_info["alreadySetup"][:] + + if "noContacts" in mysite.store.mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailverification.EmailDeliveryConfig( - CustomEPEmailService() + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), + ), + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), ), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mysite.store.mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **mysite.store.accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + if mysite.store.enabled_recipes is not None: + recipe_list = [ + item["init"] + for item in recipe_list + if item["id"] in mysite.store.enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( diff --git a/tests/auth-react/django3x/polls/urls.py b/tests/auth-react/django3x/polls/urls.py index b96151911..8b3cf999e 100644 --- a/tests/auth-react/django3x/polls/urls.py +++ b/tests/auth-react/django3x/polls/urls.py @@ -8,23 +8,49 @@ path("ping", views.ping, name="ping"), path("sessionInfo", views.session_info, name="sessionInfo"), path("token", views.token, name="token"), - path("test/setFlow", views.test_set_flow, name="setFlow"), - path("test/getDevice", views.test_get_device, name="getDevice"), - path("test/featureFlags", views.test_feature_flags, name="featureFlags"), - path("beforeeach", views.before_each, name="beforeeach"), + path("changeEmail", views.change_email, name="changeEmail"), # type: ignore + path("setupTenant", views.setup_tenant, name="setupTenant"), # type: ignore + path("removeTenant", views.remove_tenant, name="removeTenant"), # type: ignore + path( + "removeUserFromTenant", + views.remove_user_from_tenant, # type: ignore + name="removeUserFromTenant", + ), # type: ignore + path("addUserToTenant", views.add_user_to_tenant, name="addUserToTenant"), # type: ignore + path("test/setFlow", views.test_set_flow, name="setFlow"), # type: ignore + path( + "test/setAccountLinkingConfig", + views.test_set_account_linking_config, # type: ignore + name="setAccountLinkingConfig", + ), # type: ignore + path("setMFAInfo", views.set_mfa_info, name="setMfaInfo"), # type: ignore + path( + "addRequiredFactor", + views.add_required_factor, # type: ignore + name="addRequiredFactor", + ), # type: ignore + path( + "test/setEnabledRecipes", + views.test_set_enabled_recipes, # type: ignore + name="setEnabledRecipes", + ), + path("test/getTOTPCode", views.test_get_totp_code, name="getTotpCode"), # type: ignore + path("test/getDevice", views.test_get_device, name="getDevice"), # type: ignore + path("test/featureFlags", views.test_feature_flags, name="featureFlags"), # type: ignore + path("beforeeach", views.before_each, name="beforeeach"), # type: ignore ] mode = os.environ.get("APP_MODE", "asgi") if mode == "asgi": - urlpatterns += [ + urlpatterns += [ # type: ignore path("unverifyEmail", views.unverify_email_api, name="unverifyEmail"), # type: ignore path("setRole", views.set_role_api, name="setRole"), # type: ignore path("checkRole", views.check_role_api, name="checkRole"), # type: ignore path("deleteUser", views.delete_user, name="deleteUser"), # type: ignore ] else: - urlpatterns += [ + urlpatterns += [ # type: ignore path("unverifyEmail", views.sync_unverify_email_api, name="unverifyEmail"), path("setRole", views.sync_set_role_api, name="setRole"), path("checkRole", views.sync_check_role_api, name="checkRole"), diff --git a/tests/auth-react/django3x/polls/views.py b/tests/auth-react/django3x/polls/views.py index 74cc3fd7f..c531c9ddc 100644 --- a/tests/auth-react/django3x/polls/views.py +++ b/tests/auth-react/django3x/polls/views.py @@ -15,15 +15,57 @@ import os from typing import List, Dict, Any -from django.conf import settings from django.http import HttpRequest, HttpResponse, JsonResponse from mysite.store import get_codes, get_url_with_token from mysite.utils import custom_init +from supertokens_python import convert_to_recipe_user_id +from supertokens_python.asyncio import get_user +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password from supertokens_python.recipe.emailverification import EmailVerificationClaim +from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user, +) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import SessionClaimValidator +from supertokens_python.recipe.thirdparty import ProviderConfig +from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) from supertokens_python.recipe.userroles import UserRoleClaim, PermissionClaim +from supertokens_python.types import AccountInfo, RecipeUserId +from supertokens_python.recipe.multitenancy.asyncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) +from supertokens_python.recipe.passwordless.asyncio import update_user +from supertokens_python.recipe.passwordless.interfaces import ( + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, +) +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, +) mode = os.environ.get("APP_MODE", "asgi") @@ -85,7 +127,7 @@ async def set_role_api(request: HttpRequest): @verify_session() async def unverify_email_api(request: HttpRequest): session_: SessionContainer = request.supertokens # type: ignore - await unverify_email(session_.get_user_id()) + await unverify_email(session_.get_recipe_user_id()) await session_.fetch_and_set_claim(EmailVerificationClaim) return JsonResponse({"status": "OK"}) @@ -94,14 +136,16 @@ async def check_role_api(): # type: ignore return JsonResponse({"status": "OK"}) async def delete_user(request: HttpRequest): - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email + from supertokens_python.asyncio import list_users_by_account_info from supertokens_python.asyncio import delete_user body = json.loads(request.body) - user = await get_user_by_email("public", body["email"]) - if user is None: + user = await list_users_by_account_info( + "public", AccountInfo(email=body["email"]) + ) + if len(user) == 0: raise Exception("Should not come here") - await delete_user(user.user_id) + await delete_user(user[0].id) return JsonResponse({"status": "OK"}) else: @@ -139,19 +183,19 @@ def sync_set_role_api(request: HttpRequest): @verify_session() def sync_unverify_email_api(request: HttpRequest): session_: SessionContainer = request.supertokens # type: ignore - sync_unverify_email(session_.get_user_id()) + sync_unverify_email(session_.get_recipe_user_id()) session_.sync_fetch_and_set_claim(EmailVerificationClaim) return JsonResponse({"status": "OK"}) def sync_delete_user(request: HttpRequest): - from supertokens_python.recipe.emailpassword.syncio import get_user_by_email + from supertokens_python.syncio import list_users_by_account_info from supertokens_python.syncio import delete_user body = json.loads(request.body) - user = get_user_by_email("public", body["email"]) - if user is None: + user = list_users_by_account_info("public", AccountInfo(email=body["email"])) + if len(user) == 0: raise Exception("Should not come here") - delete_user(user.user_id) + delete_user(user[0].id) return JsonResponse({"status": "OK"}) @verify_session(override_global_claim_validators=override_global_claim_validators) @@ -176,16 +220,254 @@ def test_get_device(request: HttpRequest): return JsonResponse({"preAuthSessionId": pre_auth_session_id, "codes": codes}) -def test_set_flow(request: HttpRequest): +async def change_email(request: HttpRequest): body = json.loads(request.body) - contact_method = body["contactMethod"] - flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + if body is None: + raise Exception("Should never come here") + + if body["rid"] == "emailpassword": + resp = await update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return JsonResponse({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JsonResponse(resp.to_json()) + elif body["rid"] == "thirdparty": + user = await get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = await manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return JsonResponse( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return JsonResponse( + {"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason} + ) + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = await update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return JsonResponse({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return JsonResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return JsonResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JsonResponse( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": resp.reason, + } + ) + + raise Exception("Should not come here") + + +async def setup_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body.get("coreConfig", {}) + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = await create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + await create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return JsonResponse({"status": "OK", "createdNew": core_resp.created_new}) + + +async def add_user_to_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return JsonResponse( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return JsonResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return JsonResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return JsonResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return JsonResponse({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return JsonResponse( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +async def remove_user_from_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await disassociate_user_from_tenant( + tenant_id, RecipeUserId(recipe_user_id) + ) + + return JsonResponse({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +async def remove_tenant(request: HttpRequest): + body = json.loads(request.body) + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = await delete_tenant(tenant_id) + + return JsonResponse({"status": "OK", "didExist": core_resp.did_exist}) + + +async def test_set_flow(request: HttpRequest): + body = json.loads(request.body) + import mysite.store + + mysite.store.contact_method = body["contactMethod"] + mysite.store.flow_type = body["flowType"] + custom_init() + return HttpResponse("") + + +async def test_set_account_linking_config(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + mysite.store.accountlinking_config = body + custom_init() + return HttpResponse("") + + +async def set_mfa_info(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + mysite.store.mfa_info = body + return JsonResponse({"status": "OK"}) + + +@verify_session() +async def add_required_factor(request: HttpRequest): + session_: SessionContainer = request.supertokens # type: ignore + body = json.loads(request.body) + if body is None or "factorId" not in body: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + + await add_to_required_secondary_factors_for_user( + session_.get_user_id(), body["factorId"] + ) + + return JsonResponse({"status": "OK"}) + + +def test_set_enabled_recipes(request: HttpRequest): + import mysite.store + + body = json.loads(request.body) + if body is None: + raise Exception("Invalid request body") + mysite.store.enabled_recipes = body.get("enabledRecipes") + mysite.store.enabled_providers = body.get("enabledProviders") + custom_init() return HttpResponse("") +def test_get_totp_code(request: HttpRequest): + from pyotp import TOTP + + body = json.loads(request.body) + if body is None or "secret" not in body: + return JsonResponse({"error": "Invalid request body"}, status_code=400) + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return JsonResponse({"totp": code}) + + def before_each(request: HttpRequest): - setattr(settings, "CODE_STORE", dict()) + import mysite.store + + mysite.store.contact_method = "EMAIL_OR_PHONE" + mysite.store.flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + mysite.store.latest_url_with_token = "" + mysite.store.code_store = dict() + mysite.store.accountlinking_config = {} + mysite.store.enabled_providers = None + mysite.store.enabled_recipes = None + mysite.store.mfa_info = {} custom_init() return HttpResponse("") @@ -199,6 +481,11 @@ def test_feature_flags(request: HttpRequest): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] } ) diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index 67fdcd407..9d6869626 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -13,7 +13,7 @@ # under the License. import os import typing -from typing import Any, Dict, List, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union import uvicorn # type: ignore from dotenv import load_dotenv @@ -26,12 +26,14 @@ from starlette.responses import Response from starlette.types import ASGIApp from typing_extensions import Literal -from supertokens_python.recipe import multitenancy +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe import multifactorauth, multitenancy, totp from supertokens_python import ( InputAppInfo, Supertokens, SupertokensConfig, + convert_to_recipe_user_id, get_all_cors_headers, init, ) @@ -45,10 +47,18 @@ thirdparty, userroles, ) +from supertokens_python.recipe import accountlinking +from supertokens_python.recipe.accountlinking import AccountInfoWithRecipeIdAndUserId +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe +from supertokens_python.recipe.emailpassword.asyncio import update_email_or_password from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EmailPasswordAPIInterface, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, ) from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions as EPAPIOptions, @@ -72,18 +82,51 @@ APIOptions as EVAPIOptions, ) from supertokens_python.recipe.jwt import JWTRecipe +from supertokens_python.recipe.multifactorauth.asyncio import ( + add_to_required_secondary_factors_for_user, +) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe.multitenancy.asyncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.passwordless import ( ContactEmailOnlyConfig, ContactEmailOrPhoneConfig, ContactPhoneOnlyConfig, PasswordlessRecipe, ) +from supertokens_python.recipe.passwordless.asyncio import update_user from supertokens_python.recipe.passwordless.interfaces import ( APIInterface as PasswordlessAPIInterface, + PhoneNumberChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, ) from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionContainer, SessionRecipe from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) from supertokens_python.recipe.session.framework.fastapi import verify_session from supertokens_python.recipe.session.interfaces import ( APIInterface as SessionAPIInterface, @@ -91,13 +134,19 @@ from supertokens_python.recipe.session.interfaces import APIOptions as SAPIOptions from supertokens_python.recipe.session.interfaces import SessionClaimValidator from supertokens_python.recipe.thirdparty import ( + ProviderConfig, ThirdPartyRecipe, ) +from supertokens_python.recipe.thirdparty.asyncio import manually_create_or_update_user from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface as ThirdpartyAPIInterface, + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import ( PermissionClaim, @@ -108,8 +157,13 @@ add_role_to_user, create_new_role_or_add_permissions, ) -from supertokens_python.types import GeneralErrorResponse -from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email +from supertokens_python.types import ( + AccountInfo, + GeneralErrorResponse, + RecipeUserId, + User, +) +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.asyncio import delete_user load_dotenv() @@ -119,6 +173,14 @@ os.environ.setdefault("SUPERTOKENS_ENV", "testing") code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None class CustomPlessEmailService( @@ -145,9 +207,7 @@ async def send_email( code_store[template_vars.pre_auth_session_id] = codes -class CustomPlessSMSService( - passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars] -): +class CustomSMSService(passwordless.SMSDeliveryInterface[passwordless.SMSTemplateVars]): async def send_sms( self, template_vars: passwordless.SMSTemplateVars, user_context: Dict[str, Any] ) -> None: @@ -225,7 +285,7 @@ def get_website_domain(): return "http://localhost:" + get_website_port() -latest_url_with_token = None +latest_url_with_token = "" async def validate_age(value: Any, tenant_id: str): @@ -265,75 +325,48 @@ async def get_user_info( # pylint: disable=no-self-use return oi -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + global contact_method + global flow_type + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() EmailVerificationRecipe.reset() SessionRecipe.reset() ThirdPartyRecipe.reset() - EmailVerificationRecipe.reset() EmailPasswordRecipe.reset() + EmailVerificationRecipe.reset() DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() - - providers_list: List[thirdparty.ProviderInput] = [ - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="google", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["GOOGLE_CLIENT_ID"], - client_secret=os.environ["GOOGLE_CLIENT_SECRET"], - ), - ], - ), - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="github", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["GITHUB_CLIENT_ID"], - client_secret=os.environ["GITHUB_CLIENT_SECRET"], - ), - ], - ) - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="facebook", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["FACEBOOK_CLIENT_ID"], - client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], - ), - ], - ) - ), - thirdparty.ProviderInput( - config=thirdparty.ProviderConfig( - third_party_id="auth0", - name="Auth0", - authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", - authorization_endpoint_query_params={"scope": "openid profile"}, - token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", - clients=[ - thirdparty.ProviderClientConfig( - client_id=os.environ["AUTH0_CLIENT_ID"], - client_secret=os.environ["AUTH0_CLIENT_SECRET"], - ) - ], - ), - override=auth0_provider_override, - ), - ] + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -358,7 +391,11 @@ async def email_verify_post( if is_general_error: return GeneralErrorResponse("general error from API email verify") return await original_email_verify_post( - token, session, tenant_id, api_options, user_context + token, + session, + tenant_id, + api_options, + user_context, ) async def generate_email_verify_token_post( @@ -374,9 +411,7 @@ async def generate_email_verify_token_post( "general error from API email verification code" ) return await original_generate_email_verify_token_post( - session, - api_options, - user_context, + session, api_options, user_context ) original_implementation_email_verification.email_verify_post = email_verify_post @@ -447,6 +482,8 @@ async def password_reset_post( async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -460,12 +497,19 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -475,7 +519,12 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) original_implementation.email_exists_get = email_exists_get @@ -495,6 +544,8 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -508,6 +559,8 @@ async def sign_in_up_post( provider, redirect_uri_info, oauth_tokens, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -538,7 +591,7 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - session: Optional[SessionContainer], + session: SessionContainer, api_options: SAPIOptions, user_context: Dict[str, Any], ): @@ -562,6 +615,8 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -576,6 +631,8 @@ async def consume_code_post( user_input_code, device_id, link_code, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -584,6 +641,8 @@ async def consume_code_post( async def create_code_post( email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -594,12 +653,20 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, tenant_id, api_options, user_context + email, + phone_number, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) async def resend_code_post( device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -610,7 +677,13 @@ async def resend_code_post( if is_general_error: return GeneralErrorResponse("general error from API resend code") return await original_resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) original_implementation.consume_code_post = consume_code_post @@ -618,12 +691,87 @@ async def resend_code_post( original_implementation.resend_code_post = resend_code_post return original_implementation + providers_list: List[thirdparty.ProviderInput] = [ + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="google", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GOOGLE_CLIENT_ID"], + client_secret=os.environ["GOOGLE_CLIENT_SECRET"], + ), + ], + ), + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="github", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["GITHUB_CLIENT_ID"], + client_secret=os.environ["GITHUB_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="facebook", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["FACEBOOK_CLIENT_ID"], + client_secret=os.environ["FACEBOOK_CLIENT_SECRET"], + ), + ], + ) + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="auth0", + name="Auth0", + authorization_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/authorize", + authorization_endpoint_query_params={"scope": "openid profile"}, + token_endpoint=f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id=os.environ["AUTH0_CLIENT_ID"], + client_secret=os.environ["AUTH0_CLIENT_SECRET"], + ) + ], + ), + override=auth0_provider_override, + ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), + ] + + global enabled_providers + if enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in enabled_providers + ] + if contact_method is not None and flow_type is not None: if contact_method == "PHONE": passwordless_init = passwordless.init( contact_config=ContactPhoneOnlyConfig(), flow_type=flow_type, - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -646,7 +794,7 @@ async def resend_code_post( email_delivery=passwordless.EmailDeliveryConfig( CustomPlessEmailService() ), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig( apis=override_passwordless_apis ), @@ -656,7 +804,7 @@ async def resend_code_post( contact_config=ContactEmailOrPhoneConfig(), flow_type="USER_INPUT_CODE_AND_MAGIC_LINK", email_delivery=passwordless.EmailDeliveryConfig(CustomPlessEmailService()), - sms_delivery=passwordless.SMSDeliveryConfig(CustomPlessSMSService()), + sms_delivery=passwordless.SMSDeliveryConfig(CustomSMSService()), override=passwordless.InputOverrideConfig(apis=override_passwordless_apis), ) @@ -665,32 +813,243 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + global mfa_info + + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mfa_info: + return mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mfa_info: + if factor_id not in mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mfa_info: + return mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mfa_info: + res.factors.already_setup = mfa_info["alreadySetup"][:] + + if "noContacts" in mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailpassword.EmailDeliveryConfig(CustomEPEmailService()), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), + ), + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + + global accountlinking_config + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + global enabled_recipes + if enabled_recipes is not None: + recipe_list = [ + item["init"] for item in recipe_list if item["id"] in enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( @@ -719,20 +1078,272 @@ async def exception_handler(a, b): # type: ignore @app.post("/beforeeach") def before_each(): global code_store + global accountlinking_config + global enabled_providers + global enabled_recipes + global mfa_info + global latest_url_with_token + global contact_method + global flow_type + contact_method = "EMAIL_OR_PHONE" + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + latest_url_with_token = "" code_store = dict() + accountlinking_config = {} + enabled_providers = None + enabled_recipes = None + mfa_info = {} custom_init() return PlainTextResponse("") +@app.post("/changeEmail") +async def change_email(request: Request): + body: Union[dict[str, Any], None] = await request.json() + if body is None: + raise Exception("Should never come here") + + if body["rid"] == "emailpassword": + resp = await update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return JSONResponse({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return JSONResponse(resp.to_json()) + elif body["rid"] == "thirdparty": + user = await get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = await manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return JSONResponse( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return JSONResponse( + {"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason} + ) + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = await update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return JSONResponse({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return JSONResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return JSONResponse( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + if isinstance(resp, PhoneNumberChangeNotAllowedError): + return JSONResponse( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": resp.reason, + } + ) + + raise Exception("Should not come here") + + +@app.post("/setupTenant") +async def setup_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = body.get("coreConfig", {}) + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = await create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + await create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return JSONResponse({"status": "OK", "createdNew": core_resp.created_new}) + + +@app.post("/addUserToTenant") +async def add_user_to_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return JSONResponse( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return JSONResponse({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return JSONResponse({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return JSONResponse({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return JSONResponse({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return JSONResponse( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +@app.post("/removeUserFromTenant") +async def remove_user_from_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = await disassociate_user_from_tenant( + tenant_id, RecipeUserId(recipe_user_id) + ) + + return JSONResponse({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +@app.post("/removeTenant") +async def remove_tenant(request: Request): + body = await request.json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = await delete_tenant(tenant_id) + + return JSONResponse({"status": "OK", "didExist": core_resp.did_exist}) + + @app.post("/test/setFlow") async def test_set_flow(request: Request): body = await request.json() + global contact_method + global flow_type contact_method = body["contactMethod"] flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + custom_init() return PlainTextResponse("") +@app.post("/test/setAccountLinkingConfig") +async def test_set_account_linking_config(request: Request): + global accountlinking_config + body = await request.json() + if body is None: + raise Exception("Invalid request body") + accountlinking_config = body + custom_init() + return PlainTextResponse("", status_code=200) + + +@app.post("/setMFAInfo") +async def set_mfa_info(request: Request): + global mfa_info + body = await request.json() + if body is None: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + mfa_info = body + return JSONResponse({"status": "OK"}) + + +@app.post("/addRequiredFactor") +async def add_required_factor( + request: Request, session: SessionContainer = Depends(verify_session()) +): + body = await request.json() + if body is None or "factorId" not in body: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + + await add_to_required_secondary_factors_for_user( + session.get_user_id(), body["factorId"] + ) + + return JSONResponse({"status": "OK"}) + + +@app.post("/test/setEnabledRecipes") +async def test_set_enabled_recipes(request: Request): + global enabled_recipes + global enabled_providers + body = await request.json() + if body is None: + raise Exception("Invalid request body") + enabled_recipes = body.get("enabledRecipes") + enabled_providers = body.get("enabledProviders") + custom_init() + return PlainTextResponse("", status_code=200) + + +@app.post("/test/getTOTPCode") +async def test_get_totp_code(request: Request): + from pyotp import TOTP + + body = await request.json() + if body is None or "secret" not in body: + return JSONResponse({"error": "Invalid request body"}, status_code=400) + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return JSONResponse({"totp": code}) + + @app.get("/test/getDevice") def test_get_device(request: Request): global code_store @@ -751,6 +1362,11 @@ def test_feature_flags(request: Request): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] return JSONResponse({"available": available}) @@ -780,7 +1396,7 @@ async def get_token(): @app.get("/unverifyEmail") async def unverify_email_api(session_: SessionContainer = Depends(verify_session())): - await unverify_email(session_.get_user_id()) + await unverify_email(session_.get_recipe_user_id()) await session_.fetch_and_set_claim(EmailVerificationClaim) return JSONResponse({"status": "OK"}) @@ -800,10 +1416,10 @@ async def set_role_api( @app.post("/deleteUser") async def delete_user_api(request: Request): body = await request.json() - user = await get_user_by_email("public", body["email"]) - if user is None: + user = await list_users_by_account_info("public", AccountInfo(email=body["email"])) + if len(user) == 0: raise Exception("Should not come here") - await delete_user(user.user_id) + await delete_user(user[0].id) return JSONResponse({"status": "OK"}) @@ -884,4 +1500,4 @@ def preflight_response(self, request_headers: Headers) -> Response: ) if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=get_api_port()) # type: ignore + uvicorn.run(app, host="0.0.0.0", port=int(get_api_port())) # type: ignore diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 07336dc8f..3ee60c66d 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -12,7 +12,8 @@ # License for the specific language governing permissions and limitations # under the License. import os -from typing import Any, Dict, List, Optional, Union +import traceback +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from dotenv import load_dotenv from flask import Flask, g, jsonify, make_response, request @@ -26,29 +27,70 @@ get_all_cors_headers, init, ) +from supertokens_python.auth_utils import LinkingToSessionUserFailedError from supertokens_python.framework.flask.flask_middleware import Middleware from supertokens_python.framework.request import BaseRequest from supertokens_python.recipe import ( + accountlinking, emailpassword, emailverification, passwordless, session, thirdparty, + totp, userroles, ) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.types import ( + AccountInfoWithRecipeIdAndUserId, +) from supertokens_python.recipe.dashboard import DashboardRecipe from supertokens_python.recipe.emailpassword import EmailPasswordRecipe from supertokens_python.recipe.emailpassword.interfaces import ( APIInterface as EmailPasswordAPIInterface, + EmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, + UpdateEmailOrPasswordOkResult, +) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.multifactorauth.syncio import ( + add_to_required_secondary_factors_for_user, +) +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe.multitenancy.interfaces import ( + AssociateUserToTenantEmailAlreadyExistsError, + AssociateUserToTenantOkResult, + AssociateUserToTenantPhoneNumberAlreadyExistsError, + AssociateUserToTenantThirdPartyUserAlreadyExistsError, + AssociateUserToTenantUnknownUserIdError, + TenantConfigCreateOrUpdate, +) +from supertokens_python.recipe.multitenancy.syncio import ( + associate_user_to_tenant, + create_or_update_tenant, + create_or_update_third_party_config, + delete_tenant, + disassociate_user_from_tenant, +) +from supertokens_python.recipe.passwordless.syncio import update_user +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) +from supertokens_python.recipe.thirdparty.provider import ( + Provider, + RedirectUriInfo, ) -from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo from supertokens_python.recipe.emailpassword.interfaces import ( APIOptions as EPAPIOptions, ) from supertokens_python.recipe.emailpassword.types import ( FormField, InputFormField, - User, ) from supertokens_python.recipe.emailverification import ( EmailVerificationClaim, @@ -73,6 +115,11 @@ ) from supertokens_python.recipe.passwordless.interfaces import ( APIInterface as PasswordlessAPIInterface, + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, ) from supertokens_python.recipe.passwordless.interfaces import APIOptions as PAPIOptions from supertokens_python.recipe.session import SessionRecipe @@ -87,12 +134,19 @@ SessionClaimValidator, SessionContainer, ) -from supertokens_python.recipe.thirdparty import ThirdPartyRecipe +from supertokens_python.recipe.thirdparty import ( + ProviderConfig, + ThirdPartyRecipe, +) from supertokens_python.recipe.thirdparty.interfaces import ( APIInterface as ThirdpartyAPIInterface, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, ) from supertokens_python.recipe.thirdparty.interfaces import APIOptions as TPAPIOptions from supertokens_python.recipe.thirdparty.provider import Provider +from supertokens_python.recipe.thirdparty.syncio import manually_create_or_update_user +from supertokens_python.recipe.totp.recipe import TOTPRecipe from supertokens_python.recipe.userroles import ( PermissionClaim, @@ -103,9 +157,14 @@ add_role_to_user, create_new_role_or_add_permissions, ) -from supertokens_python.types import GeneralErrorResponse -from supertokens_python.recipe.emailpassword.syncio import get_user_by_email -from supertokens_python.syncio import delete_user +from supertokens_python.types import ( + AccountInfo, + RecipeUserId, + User, + GeneralErrorResponse, +) +from supertokens_python.syncio import delete_user, get_user, list_users_by_account_info +from supertokens_python.recipe import multifactorauth load_dotenv() @@ -124,9 +183,17 @@ def get_website_domain(): os.environ.setdefault("SUPERTOKENS_ENV", "testing") -latest_url_with_token = None +latest_url_with_token = "" code_store: Dict[str, List[Dict[str, Any]]] = {} +accountlinking_config: Dict[str, Any] = {} +enabled_providers: Optional[List[Any]] = None +enabled_recipes: Optional[List[Any]] = None +mfa_info: Dict[str, Any] = {} +contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None +flow_type: Union[ + None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] +] = None class CustomPlessEmailService( @@ -264,12 +331,35 @@ async def get_user_info( # pylint: disable=no-self-use return oi -def custom_init( - contact_method: Union[None, Literal["PHONE", "EMAIL", "EMAIL_OR_PHONE"]] = None, - flow_type: Union[ - None, Literal["USER_INPUT_CODE", "MAGIC_LINK", "USER_INPUT_CODE_AND_MAGIC_LINK"] - ] = None, -): +def mock_provider_override(oi: Provider) -> Provider: + async def get_user_info( + oauth_tokens: Dict[str, Any], + user_context: Dict[str, Any], + ) -> UserInfo: + user_id = oauth_tokens.get("userId", "user") + email = oauth_tokens.get("email", "email@test.com") + is_verified = oauth_tokens.get("isVerified", "true").lower() != "false" + + return UserInfo( + user_id, UserInfoEmail(email, is_verified), raw_user_info_from_provider=None + ) + + async def exchange_auth_code_for_oauth_tokens( + redirect_uri_info: RedirectUriInfo, + user_context: Dict[str, Any], + ) -> Dict[str, Any]: + return redirect_uri_info.redirect_uri_query_params + + oi.exchange_auth_code_for_oauth_tokens = exchange_auth_code_for_oauth_tokens + oi.get_user_info = get_user_info + return oi + + +def custom_init(): + global contact_method + global flow_type + + AccountLinkingRecipe.reset() UserRolesRecipe.reset() PasswordlessRecipe.reset() JWTRecipe.reset() @@ -281,6 +371,8 @@ def custom_init( DashboardRecipe.reset() MultitenancyRecipe.reset() Supertokens.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() def override_email_verification_apis( original_implementation_email_verification: EmailVerificationAPIInterface, @@ -396,6 +488,8 @@ async def password_reset_post( async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -409,12 +503,19 @@ async def sign_in_post( msg = body["generalErrorMessage"] return GeneralErrorResponse(msg) return await original_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) async def sign_up_post( form_fields: List[FormField], tenant_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: EPAPIOptions, user_context: Dict[str, Any], ): @@ -424,7 +525,12 @@ async def sign_up_post( if is_general_error: return GeneralErrorResponse("general error from API sign up") return await original_sign_up_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) original_implementation.email_exists_get = email_exists_get @@ -444,6 +550,8 @@ async def sign_in_up_post( provider: Provider, redirect_uri_info: Union[RedirectUriInfo, None], oauth_tokens: Union[Dict[str, Any], None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: TPAPIOptions, user_context: Dict[str, Any], @@ -457,6 +565,8 @@ async def sign_in_up_post( provider, redirect_uri_info, oauth_tokens, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -487,7 +597,7 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - session: Optional[SessionContainer], + session: SessionContainer, api_options: SAPIOptions, user_context: Dict[str, Any], ): @@ -511,6 +621,8 @@ async def consume_code_post( user_input_code: Union[str, None], device_id: Union[str, None], link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -525,6 +637,8 @@ async def consume_code_post( user_input_code, device_id, link_code, + session, + should_try_linking_with_session_user, tenant_id, api_options, user_context, @@ -533,6 +647,8 @@ async def consume_code_post( async def create_code_post( email: Union[str, None], phone_number: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -543,12 +659,20 @@ async def create_code_post( if is_general_error: return GeneralErrorResponse("general error from API create code") return await original_create_code_post( - email, phone_number, tenant_id, api_options, user_context + email, + phone_number, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) async def resend_code_post( device_id: str, pre_auth_session_id: str, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], tenant_id: str, api_options: PAPIOptions, user_context: Dict[str, Any], @@ -559,7 +683,13 @@ async def resend_code_post( if is_general_error: return GeneralErrorResponse("general error from API resend code") return await original_resend_code_post( - device_id, pre_auth_session_id, tenant_id, api_options, user_context + device_id, + pre_auth_session_id, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, ) original_implementation.consume_code_post = consume_code_post @@ -617,8 +747,31 @@ async def resend_code_post( ), override=auth0_provider_override, ), + thirdparty.ProviderInput( + config=thirdparty.ProviderConfig( + third_party_id="mock-provider", + name="Mock Provider", + authorization_endpoint=get_website_domain() + "/mockProvider/auth", + token_endpoint=get_website_domain() + "/mockProvider/token", + clients=[ + thirdparty.ProviderClientConfig( + client_id="supertokens", + client_secret="", + ) + ], + ), + override=mock_provider_override, + ), ] + global enabled_providers + if enabled_providers is not None: + providers_list = [ + provider + for provider in providers_list + if provider.config.third_party_id in enabled_providers + ] + if contact_method is not None and flow_type is not None: if contact_method == "PHONE": passwordless_init = passwordless.init( @@ -666,33 +819,243 @@ async def get_allowed_domains_for_tenant_id( ) -> List[str]: return [tenant_id + ".example.com", "localhost"] - recipe_list = [ - userroles.init(), - session.init(override=session.InputOverrideConfig(apis=override_session_apis)), - emailverification.init( - mode="OPTIONAL", - email_delivery=emailverification.EmailDeliveryConfig( - CustomEVEmailService() + global mfa_info + + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + APIInterface as MFAApiInterface, + APIOptions as MFAApiOptions, + ) + + def override_mfa_functions(original_implementation: MFARecipeInterface): + og_get_factors_setup_for_user = ( + original_implementation.get_factors_setup_for_user + ) + + async def get_factors_setup_for_user( + user: User, + user_context: Dict[str, Any], + ): + res = await og_get_factors_setup_for_user(user, user_context) + if "alreadySetup" in mfa_info: + return mfa_info["alreadySetup"] + return res + + og_assert_allowed_to_setup_factor = ( + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + async def assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session: SessionContainer, + factor_id: str, + mfa_requirements_for_auth: Callable[[], Awaitable[MFARequirementList]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ): + if "allowedToSetup" in mfa_info: + if factor_id not in mfa_info["allowedToSetup"]: + raise InvalidClaimsError( + msg="INVALID_CLAIMS", + payload=[ + ClaimValidationError(id_="test", reason="test override") + ], + ) + else: + await og_assert_allowed_to_setup_factor( + session, + factor_id, + mfa_requirements_for_auth, + factors_set_up_for_user, + user_context, + ) + + og_get_mfa_requirements_for_auth = ( + original_implementation.get_mfa_requirements_for_auth + ) + + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Callable[[], Awaitable[User]], + factors_set_up_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_user: Callable[[], Awaitable[List[str]]], + required_secondary_factors_for_tenant: Callable[[], Awaitable[List[str]]], + user_context: Dict[str, Any], + ) -> MFARequirementList: + res = await og_get_mfa_requirements_for_auth( + tenant_id, + access_token_payload, + completed_factors, + user, + factors_set_up_for_user, + required_secondary_factors_for_user, + required_secondary_factors_for_tenant, + user_context, + ) + if "requirements" in mfa_info: + return mfa_info["requirements"] + return res + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + + original_implementation.assert_allowed_to_setup_factor_else_throw_invalid_claim_error = ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error + ) + + original_implementation.get_factors_setup_for_user = get_factors_setup_for_user + return original_implementation + + def override_mfa_apis(original_implementation: MFAApiInterface): + og_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAApiOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + res = await og_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + if isinstance(res, ResyncSessionAndFetchMFAInfoPUTOkResult): + if "alreadySetup" in mfa_info: + res.factors.already_setup = mfa_info["alreadySetup"][:] + + if "noContacts" in mfa_info: + res.emails = {} + res.phone_numbers = {} + + return res + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + recipe_list: List[Any] = [ + {"id": "userroles", "init": userroles.init()}, + { + "id": "session", + "init": session.init( + override=session.InputOverrideConfig(apis=override_session_apis) ), - override=EVInputOverrideConfig(apis=override_email_verification_apis), - ), - emailpassword.init( - sign_up_feature=emailpassword.InputSignUpFeature(form_fields), - email_delivery=emailpassword.EmailDeliveryConfig(CustomEPEmailService()), - override=emailpassword.InputOverrideConfig( - apis=override_email_password_apis, + }, + { + "id": "emailverification", + "init": emailverification.init( + mode="OPTIONAL", + email_delivery=emailverification.EmailDeliveryConfig( + CustomEVEmailService() + ), + override=EVInputOverrideConfig(apis=override_email_verification_apis), ), - ), - thirdparty.init( - sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), - override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), - ), - passwordless_init, - multitenancy.init( - get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id - ), + }, + { + "id": "emailpassword", + "init": emailpassword.init( + sign_up_feature=emailpassword.InputSignUpFeature(form_fields), + email_delivery=emailpassword.EmailDeliveryConfig( + CustomEPEmailService() + ), + override=emailpassword.InputOverrideConfig( + apis=override_email_password_apis, + ), + ), + }, + { + "id": "thirdparty", + "init": thirdparty.init( + sign_in_and_up_feature=thirdparty.SignInAndUpFeature(providers_list), + override=thirdparty.InputOverrideConfig(apis=override_thirdparty_apis), + ), + }, + { + "id": "passwordless", + "init": passwordless_init, + }, + { + "id": "multitenancy", + "init": multitenancy.init( + get_allowed_domains_for_tenant_id=get_allowed_domains_for_tenant_id + ), + }, + { + "id": "multifactorauth", + "init": multifactorauth.init( + first_factors=mfa_info.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_mfa_functions, + apis=override_mfa_apis, + ), + ), + }, + { + "id": "totp", + "init": totp.init( + config=totp.TOTPConfig( + default_period=1, + default_skew=30, + ) + ), + }, ] + global accountlinking_config + + accountlinking_config_input = { + "enabled": False, + "shouldAutoLink": { + "shouldAutomaticallyLink": True, + "shouldRequireVerification": True, + }, + **accountlinking_config, + } + + async def should_do_automatic_account_linking( + _: AccountInfoWithRecipeIdAndUserId, + __: Optional[User], + ___: Optional[SessionContainer], + ____: str, + _____: Dict[str, Any], + ) -> Union[ + accountlinking.ShouldNotAutomaticallyLink, + accountlinking.ShouldAutomaticallyLink, + ]: + should_auto_link = accountlinking_config_input["shouldAutoLink"] + assert isinstance(should_auto_link, dict) + should_automatically_link = should_auto_link["shouldAutomaticallyLink"] + assert isinstance(should_automatically_link, bool) + if should_automatically_link: + should_require_verification = should_auto_link["shouldRequireVerification"] + assert isinstance(should_require_verification, bool) + return accountlinking.ShouldAutomaticallyLink( + should_require_verification=should_require_verification + ) + return accountlinking.ShouldNotAutomaticallyLink() + + if accountlinking_config_input["enabled"]: + recipe_list.append( + { + "id": "accountlinking", + "init": accountlinking.init( + should_do_automatic_account_linking=should_do_automatic_account_linking + ), + } + ) + + global enabled_recipes + if enabled_recipes is not None: + recipe_list = [ + item["init"] for item in recipe_list if item["id"] in enabled_recipes + ] + else: + recipe_list = [item["init"] for item in recipe_list] + init( supertokens_config=SupertokensConfig("http://localhost:9000"), app_info=InputAppInfo( @@ -731,6 +1094,174 @@ def ping(): return "success" +@app.route("/changeEmail", methods=["POST"]) # type: ignore +def change_email(): + body: Union[Any, None] = request.get_json() + if body is None: + raise Exception("Should never come here") + from supertokens_python.recipe.emailpassword.syncio import update_email_or_password + from supertokens_python import convert_to_recipe_user_id + + if body["rid"] == "emailpassword": + resp = update_email_or_password( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body["email"], + tenant_id_for_password_policy=body["tenantId"], + ) + if isinstance(resp, UpdateEmailOrPasswordOkResult): + return jsonify({"status": "OK"}) + if isinstance(resp, EmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + # password policy violation error + return jsonify(resp.to_json()) + elif body["rid"] == "thirdparty": + user = get_user(user_id=body["recipeUserId"]) + assert user is not None + login_method = next( + lm + for lm in user.login_methods + if lm.recipe_user_id.get_as_string() == body["recipeUserId"] + ) + assert login_method is not None + assert login_method.third_party is not None + resp = manually_create_or_update_user( + tenant_id=body["tenantId"], + third_party_id=login_method.third_party.id, + third_party_user_id=login_method.third_party.user_id, + email=body["email"], + is_verified=False, + ) + if isinstance(resp, ManuallyCreateOrUpdateUserOkResult): + return jsonify( + {"status": "OK", "createdNewRecipeUser": resp.created_new_recipe_user} + ) + if isinstance(resp, LinkingToSessionUserFailedError): + raise Exception("Should not come here") + if isinstance(resp, SignInUpNotAllowed): + return jsonify({"status": "SIGN_IN_UP_NOT_ALLOWED", "reason": resp.reason}) + # EmailChangeNotAllowedError + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + elif body["rid"] == "passwordless": + resp = update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + ) + + if isinstance(resp, UpdateUserOkResult): + return jsonify({"status": "OK"}) + if isinstance(resp, UpdateUserUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + if isinstance(resp, UpdateUserEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, UpdateUserPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + if isinstance(resp, EmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + return jsonify( + {"status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", "reason": resp.reason} + ) + + raise Exception("Should not come here") + + +@app.route("/setupTenant", methods=["POST"]) # type: ignore +def setup_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + login_methods = body["loginMethods"] + core_config = "coreConfig" in body and body["coreConfig"] or {} + + first_factors: List[str] = [] + if login_methods.get("emailPassword", {}).get("enabled") == True: + first_factors.append("emailpassword") + if login_methods.get("thirdParty", {}).get("enabled") == True: + first_factors.append("thirdparty") + if login_methods.get("passwordless", {}).get("enabled") == True: + first_factors.extend(["otp-phone", "otp-email", "link-phone", "link-email"]) + + core_resp = create_or_update_tenant( + tenant_id, + config=TenantConfigCreateOrUpdate( + first_factors=first_factors, + core_config=core_config, + ), + ) + + if login_methods.get("thirdParty", {}).get("providers") is not None: + for provider in login_methods["thirdParty"]["providers"]: + create_or_update_third_party_config( + tenant_id, + config=ProviderConfig.from_json(provider), + ) + + return jsonify({"status": "OK", "createdNew": core_resp.created_new}) + + +@app.route("/addUserToTenant", methods=["POST"]) # type: ignore +def add_user_to_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = associate_user_to_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + if isinstance(core_resp, AssociateUserToTenantOkResult): + return jsonify( + {"status": "OK", "wasAlreadyAssociated": core_resp.was_already_associated} + ) + elif isinstance(core_resp, AssociateUserToTenantUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(core_resp, AssociateUserToTenantThirdPartyUserAlreadyExistsError): + return jsonify({"status": "THIRD_PARTY_USER_ALREADY_EXISTS_ERROR"}) + return jsonify( + {"status": "ASSOCIATION_NOT_ALLOWED_ERROR", "reason": core_resp.reason} + ) + + +@app.route("/removeUserFromTenant", methods=["POST"]) # type: ignore +def remove_user_from_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + recipe_user_id = body["recipeUserId"] + + core_resp = disassociate_user_from_tenant(tenant_id, RecipeUserId(recipe_user_id)) + + return jsonify({"status": "OK", "wasAssociated": core_resp.was_associated}) + + +@app.route("/removeTenant", methods=["POST"]) # type: ignore +def remove_tenant(): + body = request.get_json() + if body is None: + raise Exception("Should never come here") + tenant_id = body["tenantId"] + + core_resp = delete_tenant(tenant_id) + + return jsonify({"status": "OK", "didExist": core_resp.did_exist}) + + @app.route("/sessionInfo", methods=["GET"]) # type: ignore @verify_session() def get_session_info(): @@ -754,7 +1285,21 @@ def get_token(): @app.route("/beforeeach", methods=["POST"]) # type: ignore def before_each(): global code_store + global accountlinking_config + global enabled_providers + global enabled_recipes + global mfa_info + global latest_url_with_token + global contact_method + global flow_type + contact_method = "EMAIL_OR_PHONE" + flow_type = "USER_INPUT_CODE_AND_MAGIC_LINK" + latest_url_with_token = "" code_store = dict() + accountlinking_config = {} + enabled_providers = None + enabled_recipes = None + mfa_info = {} custom_init() return "" @@ -764,12 +1309,77 @@ def test_set_flow(): body: Union[Any, None] = request.get_json() if body is None: raise Exception("Should never come here") + global contact_method + global flow_type contact_method = body["contactMethod"] flow_type = body["flowType"] - custom_init(contact_method=contact_method, flow_type=flow_type) + custom_init() return "" +@app.route("/test/setAccountLinkingConfig", methods=["POST"]) # type: ignore +def test_set_account_linking_config(): + global accountlinking_config + body = request.get_json() + if body is None: + raise Exception("Invalid request body") + accountlinking_config = body + custom_init() + return "", 200 + + +@app.route("/setMFAInfo", methods=["POST"]) # type: ignore +def set_mfa_info(): + global mfa_info + body = request.get_json() + if body is None: + return jsonify({"error": "Invalid request body"}), 400 + mfa_info = body + return jsonify({"status": "OK"}) + + +@app.route("/addRequiredFactor", methods=["POST"]) # type: ignore +@verify_session() +def add_required_factor(): + session_: SessionContainer = g.supertokens # type: ignore + + body = request.get_json() + if body is None or "factorId" not in body: + return jsonify({"error": "Invalid request body"}), 400 + + add_to_required_secondary_factors_for_user(session_.get_user_id(), body["factorId"]) + + return jsonify({"status": "OK"}) + + +@app.route("/test/setEnabledRecipes", methods=["POST"]) # type: ignore +def test_set_enabled_recipes(): + global enabled_recipes + global enabled_providers + body = request.get_json() + if body is None: + raise Exception("Invalid request body") + enabled_recipes = body.get("enabledRecipes") + enabled_providers = body.get("enabledProviders") + custom_init() + return "", 200 + + +@app.route("/test/getTOTPCode", methods=["POST"]) # type: ignore +def test_get_totp_code(): + from pyotp import TOTP + + body = request.get_json() + if body is None or "secret" not in body: + return jsonify({"error": "Invalid request body"}), 400 + + secret = body["secret"] + totp = TOTP(secret, digits=6, interval=1) + code = totp.now() + + return jsonify({"totp": code}) + + @app.get("/test/getDevice") # type: ignore def test_get_device(): global code_store @@ -788,6 +1398,11 @@ def test_feature_flags(): "generalerror", "userroles", "multitenancy", + "multitenancyManagementEndpoints", + "accountlinking", + "mfa", + "recipeConfig", + "accountlinking-fixes", ] return jsonify({"available": available}) @@ -796,7 +1411,7 @@ def test_feature_flags(): @verify_session() def unverify_email_api(): session_: SessionContainer = g.supertokens # type: ignore - unverify_email(session_.get_user_id()) + unverify_email(session_.get_recipe_user_id()) session_.sync_fetch_and_set_claim(EmailVerificationClaim) return jsonify({"status": "OK"}) @@ -816,10 +1431,10 @@ def verify_email_api(): @app.route("/deleteUser", methods=["POST"]) # type: ignore def delete_user_api(): body: Dict[str, Any] = request.get_json() # type: ignore - user = get_user_by_email("public", body["email"]) - if user is None: + user = list_users_by_account_info("public", AccountInfo(email=body["email"])) + if len(user) == 0: raise Exception("Should not come here") - delete_user(user.user_id) + delete_user(user[0].id) return jsonify({"status": "OK"}) @@ -859,6 +1474,8 @@ def index(_: str): @app.errorhandler(Exception) # type: ignore def all_exception_handler(e: Exception): + print(e) + print(traceback.format_exc()) return "Error", 500 diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index dc8320332..ab8e43763 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -2,6 +2,9 @@ from fastapi import FastAPI from pytest import fixture, mark +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, +) from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import init from supertokens_python.constants import DASHBOARD_VERSION @@ -133,9 +136,9 @@ async def should_allow_access( res = app.get(url="/auth/dashboard/api/users?limit=5") body = res.json() assert res.status_code == 200 - assert body["users"][0]["user"]["firstName"] == "User2" - assert body["users"][1]["user"]["lastName"] == "Foo" - assert body["users"][1]["user"]["firstName"] == "User1" + assert body["users"][0]["firstName"] == "User2" + assert body["users"][1]["lastName"] == "Foo" + assert body["users"][1]["firstName"] == "User1" async def test_connection_uri_has_http_prefix_if_localhost(app: TestClient): @@ -248,9 +251,11 @@ async def should_allow_access( start_st() pluser = await manually_create_or_update_user( - "public", "google", "googleid", "test@example.com" + "public", "google", "googleid", "test@example.com", True, None ) + assert isinstance(pluser, ManuallyCreateOrUpdateUserOkResult) + res = app.get( url="/auth/dashboard/api/user", params={ @@ -264,10 +269,10 @@ async def should_allow_access( res = app.get( url="/auth/dashboard/api/user", params={ - "userId": pluser.user.user_id, + "userId": pluser.user.id, "recipeId": "thirdparty", }, ) res_json = res.json() assert res_json["status"] == "OK" - assert res_json["user"]["id"] == pluser.user.user_id + assert res_json["user"]["id"] == pluser.user.id diff --git a/tests/emailpassword/test_emaildelivery.py b/tests/emailpassword/test_emaildelivery.py index 7530bdbcc..2e51e6df1 100644 --- a/tests/emailpassword/test_emaildelivery.py +++ b/tests/emailpassword/test_emaildelivery.py @@ -20,6 +20,7 @@ import respx from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -61,10 +62,6 @@ from supertokens_python.recipe.emailpassword.asyncio import ( send_reset_password_email, ) -from supertokens_python.recipe.emailpassword.interfaces import ( - SendResetPasswordEmailOkResult, - SendResetPasswordEmailUnknownUserIdError, -) respx_mock = respx.MockRouter @@ -210,7 +207,9 @@ class CustomEmailService( async def send_email( self, template_vars: emailpassword.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ) -> None: nonlocal email, password_reset_url email = template_vars.user.email @@ -254,13 +253,14 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]): oi_send_email = oi.send_email async def send_email( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal email, password_reset_url email = template_vars.user.email assert isinstance(template_vars, PasswordResetEmailTemplateVars) password_reset_url = template_vars.password_reset_link - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email return oi @@ -323,11 +323,12 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]): oi_send_email = oi.send_email async def send_email( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): template_vars.user.email = "override@example.com" assert isinstance(template_vars, PasswordResetEmailTemplateVars) - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email return oi @@ -336,7 +337,9 @@ class CustomEmailService( emailpassword.EmailDeliveryInterface[emailpassword.EmailTemplateVars] ): async def send_email( - self, template_vars: Any, user_context: Dict[str, Any] + self, + template_vars: Any, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> None: nonlocal email, password_reset_url email = template_vars.user.email @@ -388,7 +391,10 @@ async def test_reset_password_smtp_service(driver_config_client: TestClient): def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[ # pylint: disable=unused-argument + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -400,7 +406,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, password_reset_url get_content_called = True @@ -437,11 +444,12 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email_override return oi @@ -490,7 +498,8 @@ async def test_reset_password_for_non_existent_user(driver_config_client: TestCl def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -502,7 +511,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, password_reset_url get_content_called = True @@ -539,11 +549,12 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email_override return oi @@ -617,7 +628,7 @@ async def test_email_verification_default_backward_compatibility( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -683,7 +694,7 @@ async def test_email_verification_default_backward_compatibility_suppress_error( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -731,7 +742,9 @@ class CustomEmailService( async def send_email( self, template_vars: emailverification.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ) -> None: nonlocal email, email_verify_url email = template_vars.user.email @@ -766,7 +779,7 @@ async def send_email( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) res = email_verify_token_request( @@ -796,7 +809,8 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email( - template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal email, email_verify_url email = template_vars.user.email @@ -837,7 +851,7 @@ async def send_email( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -882,7 +896,9 @@ class CustomEmailDeliveryService( async def send_email( self, template_vars: emailpassword.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal email, password_reset_url email = template_vars.user.email @@ -929,7 +945,8 @@ async def test_email_verification_smtp_service(driver_config_client: TestClient) def smtp_service_override(oi: SMTPServiceInterface[VerificationEmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -941,7 +958,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: VerificationEmailTemplateVars, _user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, email_verify_url get_content_called = True @@ -978,7 +996,8 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True @@ -1017,7 +1036,7 @@ async def send_email_override( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( @@ -1049,7 +1068,9 @@ class CustomEmailDeliveryService( async def send_email( self, template_vars: PasswordResetEmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal reset_url, token_info, tenant_info, query_length password_reset_url = template_vars.password_reset_link @@ -1085,8 +1106,10 @@ async def send_email( dict_response = json.loads(response_1.text) user_info = dict_response["user"] assert dict_response["status"] == "OK" - resp = await send_reset_password_email("public", user_info["id"]) - assert isinstance(resp, SendResetPasswordEmailOkResult) + resp = await send_reset_password_email( + "public", user_info["id"], "random@gmail.com" + ) + assert resp == "OK" assert reset_url == "http://supertokens.io/auth/reset-password" assert token_info is not None and "token=" in token_info @@ -1122,9 +1145,13 @@ async def test_send_reset_password_email_invalid_input( dict_response = json.loads(response_1.text) user_info = dict_response["user"] - link = await send_reset_password_email("public", "invalidUserId") - assert isinstance(link, SendResetPasswordEmailUnknownUserIdError) + link = await send_reset_password_email( + "public", "invalidUserId", "random@gmail.com" + ) + assert link == "UNKNOWN_USER_ID_ERROR" with raises(Exception) as err: - await send_reset_password_email("invalidTenantId", user_info["id"]) + await send_reset_password_email( + "invalidTenantId", user_info["id"], "random@gmail.com" + ) assert "status code: 400" in str(err.value) diff --git a/tests/emailpassword/test_emailexists.py b/tests/emailpassword/test_emailexists.py index a8425b027..748dde85d 100644 --- a/tests/emailpassword/test_emailexists.py +++ b/tests/emailpassword/test_emailexists.py @@ -16,6 +16,7 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -50,7 +51,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index bda081c8d..06314917e 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -19,6 +19,7 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -44,7 +45,7 @@ VerifyEmailUsingTokenInvalidTokenError, ) from supertokens_python.recipe.emailverification.types import ( - User as EVUser, + EmailVerificationUser as EVUser, ) from supertokens_python.recipe.emailverification.utils import OverrideConfig from supertokens_python.recipe.session import SessionContainer @@ -84,7 +85,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") @@ -197,7 +198,9 @@ async def test_the_generate_token_api_with_valid_input_email_verified_and_test_e user_id = dict_response["user"]["id"] cookies = extract_all_cookies(response_1) - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) @@ -741,7 +744,7 @@ async def email_verify_post( await asyncio.sleep(1) if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -971,7 +974,7 @@ async def email_verify_post( if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -1082,7 +1085,7 @@ async def email_verify_post( if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -1120,8 +1123,10 @@ async def test_the_generate_token_api_with_valid_input_and_then_remove_token( assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - verify_token = await create_email_verification_token("public", user_id) - await revoke_email_verification_tokens("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) + await revoke_email_verification_tokens("public", RecipeUserId(user_id)) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): response = await verify_email_using_token("public", verify_token.token) @@ -1164,15 +1169,17 @@ async def test_the_generate_token_api_with_valid_input_verify_and_then_unverify_ assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) - assert await is_email_verified(user_id) + assert await is_email_verified(RecipeUserId(user_id)) - await unverify_email(user_id) + await unverify_email(RecipeUserId(user_id)) - is_verified = await is_email_verified(user_id) + is_verified = await is_email_verified(RecipeUserId(user_id)) assert is_verified is False return raise Exception("Test failed") @@ -1268,7 +1275,9 @@ async def send_email( cookies = extract_all_cookies(res) # Start verification: - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) assert isinstance(verify_token, CreateEmailVerificationTokenOkResult) await verify_email_using_token("public", verify_token.token) @@ -1295,10 +1304,10 @@ async def send_email( ) assert res.status_code == 200 assert res.json() == {"status": "EMAIL_ALREADY_VERIFIED_ERROR"} - assert "front-token" not in res.headers + assert res.headers.get("front-token") is not None # now we mark the email as unverified and try again: - await unverify_email(user_id) + await unverify_email(RecipeUserId(user_id)) res = email_verify_token_request( driver_config_client, cookies["sAccessToken"]["value"], @@ -1369,16 +1378,16 @@ def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str: dict_response = json.loads(response_1.text) assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - email = dict_response["user"]["email"] + email = dict_response["user"]["emails"][0] await send_email_verification_email( - "public", user_id, email, {"url": "localhost:3000"} + "public", user_id, RecipeUserId(user_id), email, {"url": "localhost:3000"} ) url = urlparse(email_verify_link) assert url.netloc == "localhost:3000" await send_email_verification_email( - "public", user_id, email, {"url": "localhost:3002"} + "public", user_id, RecipeUserId(user_id), email, {"url": "localhost:3002"} ) url = urlparse(email_verify_link) assert url.netloc == "localhost:3002" diff --git a/tests/emailpassword/test_multitenancy.py b/tests/emailpassword/test_multitenancy.py index eaa87d5bd..9b7e521e0 100644 --- a/tests/emailpassword/test_multitenancy.py +++ b/tests/emailpassword/test_multitenancy.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from pytest import mark +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.recipe import session, userroles, emailpassword, multitenancy from supertokens_python import init from supertokens_python.recipe.multitenancy.asyncio import ( @@ -20,8 +21,6 @@ from supertokens_python.recipe.emailpassword.asyncio import ( sign_up, sign_in, - get_user_by_id, - get_user_by_email, create_reset_password_token, reset_password_using_token, ) @@ -30,7 +29,10 @@ SignInOkResult, CreateResetPasswordOkResult, ) -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) +from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args from tests.utils import ( @@ -62,9 +64,15 @@ async def test_multitenancy_in_emailpassword(): setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(email_password_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) user1 = await sign_up("t1", "test@example.com", "password1") user2 = await sign_up("t2", "test@example.com", "password2") @@ -74,9 +82,9 @@ async def test_multitenancy_in_emailpassword(): assert isinstance(user2, SignUpOkResult) assert isinstance(user3, SignUpOkResult) - assert user1.user.user_id != user2.user.user_id - assert user2.user.user_id != user3.user.user_id - assert user3.user.user_id != user1.user.user_id + assert user1.user.id != user2.user.id + assert user2.user.id != user3.user.id + assert user3.user.id != user1.user.id assert user1.user.tenant_ids == ["t1"] assert user2.user.tenant_ids == ["t2"] @@ -91,32 +99,44 @@ async def test_multitenancy_in_emailpassword(): assert isinstance(ep_user2, SignInOkResult) assert isinstance(ep_user3, SignInOkResult) - assert ep_user1.user.user_id == user1.user.user_id - assert ep_user2.user.user_id == user2.user.user_id - assert ep_user3.user.user_id == user3.user.user_id + assert ep_user1.user.id == user1.user.id + assert ep_user2.user.id == user2.user.id + assert ep_user3.user.id == user3.user.id # get user by id: - g_user1 = await get_user_by_id(user1.user.user_id) - g_user2 = await get_user_by_id(user2.user.user_id) - g_user3 = await get_user_by_id(user3.user.user_id) + g_user1 = await get_user(user1.user.id) + g_user2 = await get_user(user2.user.id) + g_user3 = await get_user(user3.user.id) assert g_user1 == user1.user assert g_user2 == user2.user assert g_user3 == user3.user # get user by email: - by_email_user1 = await get_user_by_email("t1", "test@example.com") - by_email_user2 = await get_user_by_email("t2", "test@example.com") - by_email_user3 = await get_user_by_email("t3", "test@example.com") + by_email_user1 = await list_users_by_account_info( + "t1", AccountInfo(email="test@example.com") + ) + by_email_user2 = await list_users_by_account_info( + "t2", AccountInfo(email="test@example.com") + ) + by_email_user3 = await list_users_by_account_info( + "t3", AccountInfo(email="test@example.com") + ) - assert by_email_user1 == user1.user - assert by_email_user2 == user2.user - assert by_email_user3 == user3.user + assert by_email_user1[0] == user1.user + assert by_email_user2[0] == user2.user + assert by_email_user3[0] == user3.user # create password reset token: - pless_reset_link1 = await create_reset_password_token("t1", user1.user.user_id) - pless_reset_link2 = await create_reset_password_token("t2", user2.user.user_id) - pless_reset_link3 = await create_reset_password_token("t3", user3.user.user_id) + pless_reset_link1 = await create_reset_password_token( + "t1", user1.user.id, user1.user.emails[0] + ) + pless_reset_link2 = await create_reset_password_token( + "t2", user2.user.id, user2.user.emails[0] + ) + pless_reset_link3 = await create_reset_password_token( + "t3", user3.user.id, user3.user.emails[0] + ) assert isinstance(pless_reset_link1, CreateResetPasswordOkResult) assert isinstance(pless_reset_link2, CreateResetPasswordOkResult) diff --git a/tests/emailpassword/test_passwordreset.py b/tests/emailpassword/test_passwordreset.py index f684b3290..26856b78b 100644 --- a/tests/emailpassword/test_passwordreset.py +++ b/tests/emailpassword/test_passwordreset.py @@ -18,6 +18,8 @@ from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId +from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, raises from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -26,7 +28,7 @@ from supertokens_python.recipe import emailpassword, session from supertokens_python.recipe.emailpassword.asyncio import create_reset_password_link from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordLinkUnknownUserIdError, + UnknownUserIdError, ) from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.asyncio import ( @@ -34,7 +36,6 @@ get_session, refresh_session, ) -from tests.testclient import TestClientWithNoCookieJar as TestClient from tests.utils import clean_st, reset, setup_st, sign_up_request, start_st @@ -57,7 +58,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") @@ -398,7 +399,7 @@ async def send_email( dict_response = json.loads(response_4.text) assert dict_response["status"] == "OK" assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio @@ -428,19 +429,20 @@ async def test_create_reset_password_link( dict_response = json.loads(response_1.text) user_info = dict_response["user"] assert dict_response["status"] == "OK" - link = await create_reset_password_link("public", user_info["id"]) - url = urlparse(link.link) # type: ignore + link = await create_reset_password_link("public", user_info["id"], "") + assert isinstance(link, str) + url = urlparse(link) queries = url.query.strip("&").split("&") assert url.path == "/auth/reset-password" assert "token=" in queries[0] assert "tenantId=public" in queries assert "rid=emailpassword" not in queries - link = await create_reset_password_link("public", "invalidUserId") - assert isinstance(link, CreateResetPasswordLinkUnknownUserIdError) + link = await create_reset_password_link("public", "invalidUserId", "") + assert isinstance(link, UnknownUserIdError) with raises(Exception) as err: - await create_reset_password_link("invalidTenantId", user_info["id"]) + await create_reset_password_link("invalidTenantId", user_info["id"], "") assert "status code: 400" in str(err.value) diff --git a/tests/emailpassword/test_signin.py b/tests/emailpassword/test_signin.py index 1bc192262..42f1b86d7 100644 --- a/tests/emailpassword/test_signin.py +++ b/tests/emailpassword/test_signin.py @@ -30,6 +30,7 @@ get_session, refresh_session, ) +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from tests.testclient import TestClientWithNoCookieJar as TestClient from tests.utils import ( @@ -61,7 +62,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/refresh") @@ -179,7 +180,7 @@ async def test_singinAPI_works_when_input_is_fine(driver_config_client: TestClie assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio @@ -224,7 +225,7 @@ async def test_singinAPI_works_when_input_is_fine_when_rid_is_tpep( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio @@ -269,7 +270,7 @@ async def test_singinAPI_works_when_input_is_fine_when_rid_is_emailpassword( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"] == user_info["emails"] @mark.asyncio diff --git a/tests/emailpassword/test_updateemailorpassword.py b/tests/emailpassword/test_updateemailorpassword.py index 586e7d5c4..71a53112b 100644 --- a/tests/emailpassword/test_updateemailorpassword.py +++ b/tests/emailpassword/test_updateemailorpassword.py @@ -15,6 +15,7 @@ from typing import Any from fastapi import FastAPI +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -73,7 +74,7 @@ async def test_update_email_or_password_with_default_validator( user_id = dict_response["user"]["id"] r = await update_email_or_password( - user_id=user_id, + recipe_user_id=RecipeUserId(user_id), email=None, password="test", user_context={}, @@ -126,7 +127,7 @@ async def validate_pass(value: Any, _tenant_id: str): user_id = dict_response["user"]["id"] r = await update_email_or_password( - user_id=user_id, + recipe_user_id=RecipeUserId(user_id), email=None, password="te", user_context={}, diff --git a/tests/frontendIntegration/django2x/polls/views.py b/tests/frontendIntegration/django2x/polls/views.py index b2d725449..eda19409a 100644 --- a/tests/frontendIntegration/django2x/polls/views.py +++ b/tests/frontendIntegration/django2x/polls/views.py @@ -48,6 +48,7 @@ merge_into_access_token_payload, ) from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.syncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -279,6 +280,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -290,6 +292,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -402,7 +405,7 @@ def login(request: HttpRequest): if request.method == "POST": user_id = json.loads(request.body)["userId"] - session_ = create_new_session(request, "public", user_id) + session_ = create_new_session(request, "public", RecipeUserId(user_id)) return HttpResponse(session_.get_user_id()) else: return send_options_api_response() diff --git a/tests/frontendIntegration/django3x/polls/views.py b/tests/frontendIntegration/django3x/polls/views.py index 6cfc59c76..1e182fb76 100644 --- a/tests/frontendIntegration/django3x/polls/views.py +++ b/tests/frontendIntegration/django3x/polls/views.py @@ -48,6 +48,7 @@ from supertokens_python.recipe.session.asyncio import merge_into_access_token_payload from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -282,6 +283,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -293,6 +295,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -405,7 +408,7 @@ async def login(request: HttpRequest): if request.method == "POST": user_id = json.loads(request.body)["userId"] - session_ = await create_new_session(request, "public", user_id) + session_ = await create_new_session(request, "public", RecipeUserId(user_id)) return HttpResponse(session_.get_user_id()) else: return send_options_api_response() diff --git a/tests/frontendIntegration/drf_async/mysite/settings.py b/tests/frontendIntegration/drf_async/mysite/settings.py index f30de3e1f..0ba39a229 100644 --- a/tests/frontendIntegration/drf_async/mysite/settings.py +++ b/tests/frontendIntegration/drf_async/mysite/settings.py @@ -16,6 +16,8 @@ # Build paths inside the project like this: os.path.join(BASE_DIR, ...) from corsheaders.defaults import default_headers +from supertokens_python.types import RecipeUserId + BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Quick-start development settings - unsuitable for production @@ -173,6 +175,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -184,6 +187,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/drf_async/polls/views.py b/tests/frontendIntegration/drf_async/polls/views.py index ddab05beb..e13848aef 100644 --- a/tests/frontendIntegration/drf_async/polls/views.py +++ b/tests/frontendIntegration/drf_async/polls/views.py @@ -54,6 +54,7 @@ from supertokens_python.recipe.session.asyncio import merge_into_access_token_payload from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -306,6 +307,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -317,6 +319,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -431,7 +434,7 @@ async def login(request: Request): # type: ignore if request.method == "POST": # type: ignore user_id = request.data["userId"] # type: ignore - session_ = await create_new_session(request, "public", user_id) # type: ignore + session_ = await create_new_session(request, "public", RecipeUserId(user_id)) # type: ignore return Response(session_.get_user_id()) # type: ignore else: return send_options_api_response() # type: ignore diff --git a/tests/frontendIntegration/drf_sync/mysite/settings.py b/tests/frontendIntegration/drf_sync/mysite/settings.py index 2f2c39b2a..bfcf33d19 100644 --- a/tests/frontendIntegration/drf_sync/mysite/settings.py +++ b/tests/frontendIntegration/drf_sync/mysite/settings.py @@ -16,6 +16,8 @@ # Build paths inside the project like this: os.path.join(BASE_DIR, ...) from corsheaders.defaults import default_headers +from supertokens_python.types import RecipeUserId + BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Quick-start development settings - unsuitable for production @@ -173,6 +175,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -184,6 +187,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/frontendIntegration/drf_sync/polls/views.py b/tests/frontendIntegration/drf_sync/polls/views.py index 148c7ce98..cd776104b 100644 --- a/tests/frontendIntegration/drf_sync/polls/views.py +++ b/tests/frontendIntegration/drf_sync/polls/views.py @@ -54,6 +54,7 @@ from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.syncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -306,6 +307,7 @@ def functions_override_session(param: RecipeInterface): def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -317,6 +319,7 @@ def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -431,7 +434,7 @@ def login(request: Request): # type: ignore if request.method == "POST": # type: ignore user_id = request.data["userId"] # type: ignore - session_ = create_new_session(request, "public", user_id) # type: ignore + session_ = create_new_session(request, "public", RecipeUserId(user_id)) # type: ignore return Response(session_.get_user_id()) # type: ignore else: return send_options_api_response() # type: ignore diff --git a/tests/frontendIntegration/fastapi-server/app.py b/tests/frontendIntegration/fastapi-server/app.py index 013e89b32..addaf27ab 100644 --- a/tests/frontendIntegration/fastapi-server/app.py +++ b/tests/frontendIntegration/fastapi-server/app.py @@ -52,6 +52,7 @@ RecipeInterface, ) from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.asyncio import get_session_information from supertokens_python.querier import Querier @@ -136,6 +137,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -147,6 +149,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -270,7 +273,7 @@ def login_options(): @app.post("/login") async def login(request: Request): user_id = (await request.json())["userId"] - _session = await create_new_session(request, "public", user_id) + _session = await create_new_session(request, "public", RecipeUserId(user_id)) return PlainTextResponse(content=_session.get_user_id()) diff --git a/tests/frontendIntegration/flask-server/app.py b/tests/frontendIntegration/flask-server/app.py index f99577eb6..e829b61fc 100644 --- a/tests/frontendIntegration/flask-server/app.py +++ b/tests/frontendIntegration/flask-server/app.py @@ -44,6 +44,7 @@ ) from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe from supertokens_python.constants import VERSION +from supertokens_python.types import RecipeUserId from supertokens_python.utils import is_version_gte from supertokens_python.recipe.session.syncio import get_session_information from supertokens_python.normalised_url_path import NormalisedURLPath @@ -162,6 +163,7 @@ def functions_override_session(param: RecipeInterface): async def create_new_session_custom( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[Dict[str, Any], None], session_data_in_database: Union[Dict[str, Any], None], disable_anti_csrf: Union[bool, None], @@ -173,6 +175,7 @@ async def create_new_session_custom( access_token_payload = {**access_token_payload, "customClaim": "customValue"} return await original_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -295,7 +298,7 @@ def login_options(): @app.route("/login", methods=["POST"]) # type: ignore def login(): user_id: str = request.get_json()["userId"] # type: ignore - _session = create_new_session(request, "public", user_id) + _session = create_new_session(request, "public", RecipeUserId(user_id)) return _session.get_user_id() diff --git a/tests/input_validation/test_input_validation.py b/tests/input_validation/test_input_validation.py index 08728031c..f66e22f22 100644 --- a/tests/input_validation/test_input_validation.py +++ b/tests/input_validation/test_input_validation.py @@ -17,6 +17,7 @@ from supertokens_python.recipe.emailverification.interfaces import ( GetEmailForUserIdOkResult, ) +from supertokens_python.types import RecipeUserId @pytest.mark.asyncio @@ -53,7 +54,7 @@ async def test_init_validation_emailpassword(): ) -async def get_email_for_user_id(_: str, __: Dict[str, Any]): +async def get_email_for_user_id(_: RecipeUserId, __: Dict[str, Any]): return GetEmailForUserIdOkResult("foo@example.com") @@ -89,7 +90,7 @@ async def test_init_validation_emailverification(): recipe_list=[ emailverification.init( mode="OPTIONAL", - get_email_for_user_id=get_email_for_user_id, + get_email_for_recipe_user_id=get_email_for_user_id, override="override", # type: ignore ) ], diff --git a/tests/jwt/test_get_JWKS.py b/tests/jwt/test_get_JWKS.py index 0cde0912c..20c945ab1 100644 --- a/tests/jwt/test_get_JWKS.py +++ b/tests/jwt/test_get_JWKS.py @@ -25,6 +25,7 @@ from supertokens_python.recipe import jwt from supertokens_python.recipe.jwt.interfaces import APIInterface, RecipeInterface from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.types import RecipeUserId from tests.utils import clean_st, reset, setup_st, start_st @@ -50,7 +51,7 @@ async def driver_config_client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} return TestClient(app) diff --git a/tests/multitenancy/test_tenants_crud.py b/tests/multitenancy/test_tenants_crud.py index 2fb10fe00..f048fc99a 100644 --- a/tests/multitenancy/test_tenants_crud.py +++ b/tests/multitenancy/test_tenants_crud.py @@ -17,8 +17,10 @@ from typing import Any, Dict from supertokens_python import init +from supertokens_python.asyncio import get_user from supertokens_python.framework.fastapi import get_middleware from supertokens_python.recipe import emailpassword, multitenancy, session +from supertokens_python.types import RecipeUserId from tests.utils import ( setup_function, teardown_function, @@ -40,11 +42,13 @@ create_or_update_third_party_config, delete_third_party_config, associate_user_to_tenant, - dissociate_user_from_tenant, + disassociate_user_from_tenant, ) -from supertokens_python.recipe.emailpassword.asyncio import sign_up, get_user_by_id +from supertokens_python.recipe.emailpassword.asyncio import sign_up from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.thirdparty.provider import ( ProviderConfig, ProviderClientConfig, @@ -67,51 +71,72 @@ async def test_tenant_crud(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t2", + TenantConfigCreateOrUpdate( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) tenants = await list_all_tenants() assert len(tenants.tenants) == 4 t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config.emailpassword.enabled is True - assert t1_config.passwordless.enabled is False - assert t1_config.third_party.enabled is False + assert t1_config.first_factors is not None + assert "emailpassword" in t1_config.first_factors + assert len(t1_config.first_factors) == 1 assert t1_config.core_config == {} t2_config = await get_tenant("t2") assert t2_config is not None - assert t2_config.emailpassword.enabled is False - assert t2_config.passwordless.enabled is True - assert t2_config.third_party.enabled is False + assert t2_config.first_factors is not None + assert "otp-email" in t2_config.first_factors + assert "otp-phone" in t2_config.first_factors + assert "link-email" in t2_config.first_factors + assert "link-phone" in t2_config.first_factors + assert len(t2_config.first_factors) == 4 assert t2_config.core_config == {} t3_config = await get_tenant("t3") assert t3_config is not None - assert t3_config.emailpassword.enabled is False - assert t3_config.passwordless.enabled is False - assert t3_config.third_party.enabled is True + assert t3_config.first_factors is not None + assert "thirdparty" in t3_config.first_factors + assert len(t3_config.first_factors) == 1 assert t3_config.core_config == {} # update tenant1 to add passwordless: - await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True)) + await create_or_update_tenant( + "t1", + TenantConfigCreateOrUpdate( + first_factors=[ + "otp-email", + ] + ), + ) t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config.emailpassword.enabled is True - assert t1_config.passwordless.enabled is True - assert t1_config.third_party.enabled is False + assert t1_config.first_factors is not None + assert "otp-email" in t1_config.first_factors + assert len(t1_config.first_factors) == 1 assert t1_config.core_config == {} # update tenant1 to add thirdparty: - await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty", "otp-email"]) + ) t1_config = await get_tenant("t1") assert t1_config is not None - assert t1_config is not None - assert t1_config.emailpassword.enabled is True - assert t1_config.passwordless.enabled is True - assert t1_config.third_party.enabled is True + assert t1_config.first_factors is not None + assert "otp-email" in t1_config.first_factors + assert "thirdparty" in t1_config.first_factors + assert len(t1_config.first_factors) == 2 assert t1_config.core_config == {} # delete tenant2: @@ -126,7 +151,9 @@ async def test_tenant_thirdparty_config(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) await create_or_update_third_party_config( "t1", config=ProviderConfig( @@ -139,8 +166,8 @@ async def test_tenant_thirdparty_config(): tenant_config = await get_tenant("t1") assert tenant_config is not None - assert len(tenant_config.third_party.providers) == 1 - provider = tenant_config.third_party.providers[0] + assert len(tenant_config.third_party_providers) == 1 + provider = tenant_config.third_party_providers[0] assert provider.third_party_id == "google" assert provider.clients is not None assert len(provider.clients) == 1 @@ -197,8 +224,8 @@ async def generate_fake_email(_: str, __: str, ___: Dict[str, Any]): tenant_config = await get_tenant("t1") assert tenant_config is not None - assert len(tenant_config.third_party.providers) == 1 - provider = tenant_config.third_party.providers[0] + assert len(tenant_config.third_party_providers) == 1 + provider = tenant_config.third_party_providers[0] assert provider.third_party_id == "google" assert provider.name == "Custom name" assert provider.clients is not None @@ -244,7 +271,7 @@ async def generate_fake_email(_: str, __: str, ___: Dict[str, Any]): tenant_config = await get_tenant("t1") assert tenant_config is not None - assert len(tenant_config.third_party.providers) == 0 + assert len(tenant_config.third_party_providers) == 0 async def test_user_association_and_disassociation_with_tenants(): @@ -253,26 +280,35 @@ async def test_user_association_and_disassociation_with_tenants(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t2", + TenantConfigCreateOrUpdate( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) signup_response = await sign_up("public", "test@example.com", "password1") assert isinstance(signup_response, SignUpOkResult) - user_id = signup_response.user.user_id + user_id = signup_response.user.id - await associate_user_to_tenant("t1", user_id) - await associate_user_to_tenant("t2", user_id) - await associate_user_to_tenant("t3", user_id) + await associate_user_to_tenant("t1", RecipeUserId(user_id)) + await associate_user_to_tenant("t2", RecipeUserId(user_id)) + await associate_user_to_tenant("t3", RecipeUserId(user_id)) - user = await get_user_by_id(user_id) + user = await get_user(user_id) assert user is not None assert len(user.tenant_ids) == 4 # public + 3 tenants - await dissociate_user_from_tenant("t1", user_id) - await dissociate_user_from_tenant("t2", user_id) - await dissociate_user_from_tenant("t3", user_id) + await disassociate_user_from_tenant("t1", RecipeUserId(user_id)) + await disassociate_user_from_tenant("t2", RecipeUserId(user_id)) + await disassociate_user_from_tenant("t3", RecipeUserId(user_id)) - user = await get_user_by_id(user_id) + user = await get_user(user_id) assert user is not None assert len(user.tenant_ids) == 1 # public only diff --git a/tests/passwordless/test_emaildelivery.py b/tests/passwordless/test_emaildelivery.py index ea8fa26ec..60f4bf8c9 100644 --- a/tests/passwordless/test_emaildelivery.py +++ b/tests/passwordless/test_emaildelivery.py @@ -59,7 +59,7 @@ create_email_verification_token, ) from supertokens_python.recipe.emailverification.interfaces import ( - CreateEmailVerificationTokenOkResult, + CreateEmailVerificationTokenEmailAlreadyVerifiedError, ) @@ -181,13 +181,14 @@ async def send_email_override( if not is_version_gte(version, "2.11"): return - pless_response = await signinup("public", "test@example.com", None, {}) + pless_response = await signinup("public", "test@example.com", None, None, {}) create_token = await create_email_verification_token( - "public", pless_response.user.user_id + "public", pless_response.recipe_user_id ) - assert isinstance(create_token, CreateEmailVerificationTokenOkResult) - # TODO: Replaced CreateEmailVerificationTokenEmailAlreadyVerifiedError. Confirm if this is correct. + assert isinstance( + create_token, CreateEmailVerificationTokenEmailAlreadyVerifiedError + ) assert ( all([outer_override_called, get_content_called, send_raw_email_called]) is False diff --git a/tests/passwordless/test_mutlitenancy.py b/tests/passwordless/test_mutlitenancy.py index 1121b11da..c1ff18bdf 100644 --- a/tests/passwordless/test_mutlitenancy.py +++ b/tests/passwordless/test_mutlitenancy.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from pytest import mark +from supertokens_python.asyncio import get_user, list_users_by_account_info from supertokens_python.recipe import session, multitenancy, passwordless from supertokens_python import init from supertokens_python.recipe.multitenancy.asyncio import ( @@ -20,11 +21,12 @@ from supertokens_python.recipe.passwordless.asyncio import ( create_code, consume_code, - get_user_by_id, - get_user_by_email, ConsumeCodeOkResult, ) -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) +from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args from tests.utils import ( @@ -57,9 +59,24 @@ async def test_multitenancy_functions(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(passwordless_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(passwordless_enabled=True)) + await create_or_update_tenant( + "t1", + TenantConfigCreateOrUpdate( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant( + "t2", + TenantConfigCreateOrUpdate( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) + await create_or_update_tenant( + "t3", + TenantConfigCreateOrUpdate( + first_factors=["otp-email", "otp-phone", "link-email", "link-phone"] + ), + ) code1 = await create_code( tenant_id="t1", email="test@example.com", user_input_code="123456" @@ -94,28 +111,34 @@ async def test_multitenancy_functions(): assert isinstance(user2, ConsumeCodeOkResult) assert isinstance(user3, ConsumeCodeOkResult) - assert user1.user.user_id != user2.user.user_id - assert user2.user.user_id != user3.user.user_id - assert user3.user.user_id != user1.user.user_id + assert user1.user.id != user2.user.id + assert user2.user.id != user3.user.id + assert user3.user.id != user1.user.id assert user1.user.tenant_ids == ["t1"] assert user2.user.tenant_ids == ["t2"] assert user3.user.tenant_ids == ["t3"] # get user by id: - g_user1 = await get_user_by_id(user1.user.user_id) - g_user2 = await get_user_by_id(user2.user.user_id) - g_user3 = await get_user_by_id(user3.user.user_id) + g_user1 = await get_user(user1.user.id) + g_user2 = await get_user(user2.user.id) + g_user3 = await get_user(user3.user.id) assert g_user1 == user1.user assert g_user2 == user2.user assert g_user3 == user3.user # get user by email: - by_email_user1 = await get_user_by_email("t1", "test@example.com") - by_email_user2 = await get_user_by_email("t2", "test@example.com") - by_email_user3 = await get_user_by_email("t3", "test@example.com") + by_email_user1 = await list_users_by_account_info( + "t1", AccountInfo(email="test@example.com") + ) + by_email_user2 = await list_users_by_account_info( + "t2", AccountInfo(email="test@example.com") + ) + by_email_user3 = await list_users_by_account_info( + "t3", AccountInfo(email="test@example.com") + ) - assert by_email_user1 == user1.user - assert by_email_user2 == user2.user - assert by_email_user3 == user3.user + assert by_email_user1 == [user1.user] + assert by_email_user2 == [user2.user] + assert by_email_user3 == [user3.user] diff --git a/tests/sessions/claims/test_assert_claims.py b/tests/sessions/claims/test_assert_claims.py index f1f303884..c93718f30 100644 --- a/tests/sessions/claims/test_assert_claims.py +++ b/tests/sessions/claims/test_assert_claims.py @@ -16,6 +16,7 @@ ) from supertokens_python.recipe.session.session_class import Session from supertokens_python import init +from supertokens_python.types import RecipeUserId from tests.utils import ( get_st_init_args, setup_function, @@ -58,6 +59,7 @@ async def test_should_not_throw_for_empty_array(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, False, # access_token_updated @@ -96,6 +98,7 @@ async def test_should_call_validate_with_the_same_payload_object(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), payload, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -120,7 +123,9 @@ async def validate( def should_refetch(self, payload: JSONObject, user_context: Dict[str, Any]): return False - dummy_claim = PrimitiveClaim("st-claim", lambda _, __, ___: "Hello world") + dummy_claim = PrimitiveClaim( + "st-claim", lambda _, __, ___, ____, _____: "Hello world" + ) dummy_claim_validator = DummyClaimValidator(dummy_claim) @@ -142,5 +147,7 @@ async def test_assert_claims_should_work(): start_st() validator = TrueClaim.validators.is_true(1) - s = await create_new_session_without_request_response("public", "userid", {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userid"), {} + ) await s.assert_claims([validator]) diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index a6e5a2195..dbe8f76c0 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -4,6 +4,7 @@ from supertokens_python.framework import BaseRequest from supertokens_python.recipe import session from supertokens_python.recipe.session.asyncio import create_new_session +from supertokens_python.types import RecipeUserId from tests.utils import ( setup_function, teardown_function, @@ -28,10 +29,10 @@ async def test_create_access_token_payload_with_session_claims(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"v": True, "t": timestamp} @@ -41,10 +42,10 @@ async def test_should_create_access_token_payload_with_session_claims_with_an_no start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 9 + assert len(payload) == 10 assert payload.get("st-true") is None @@ -66,9 +67,9 @@ async def test_should_merge_claims_and_passed_access_token_payload_obj(timestamp start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 11 + assert len(payload) == 12 assert payload["st-true"] == {"v": True, "t": timestamp} assert payload["user-custom-claim"] == "foo" diff --git a/tests/sessions/claims/test_fetch_and_set_claim.py b/tests/sessions/claims/test_fetch_and_set_claim.py index 8e8d3ccad..66c0c2d50 100644 --- a/tests/sessions/claims/test_fetch_and_set_claim.py +++ b/tests/sessions/claims/test_fetch_and_set_claim.py @@ -3,6 +3,7 @@ from pytest import mark from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import NoneClaim, TrueClaim from tests.utils import AsyncMock, MagicMock @@ -28,6 +29,7 @@ async def test_should_not_change_if_claim_fetch_value_returns_none(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, False, # access_token_updated @@ -60,6 +62,7 @@ async def test_should_update_if_claim_fetch_value_returns_value(timestamp: int): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated diff --git a/tests/sessions/claims/test_get_claim_value.py b/tests/sessions/claims/test_get_claim_value.py index cea4c1afa..93d6db915 100644 --- a/tests/sessions/claims/test_get_claim_value.py +++ b/tests/sessions/claims/test_get_claim_value.py @@ -14,6 +14,7 @@ GetClaimValueOkResult, SessionDoesNotExistError, ) +from supertokens_python.types import RecipeUserId from tests.utils import setup_function, teardown_function, start_st, st_init_common_args from .utils import TrueClaim, get_st_init_args @@ -28,7 +29,7 @@ async def test_should_get_the_right_value(): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) res = await s.get_claim_value(TrueClaim) assert res is True @@ -39,7 +40,9 @@ async def test_should_get_the_right_value_using_session_handle(): start_st() dummy_req: BaseRequest = MagicMock() - s: SessionContainer = await create_new_session(dummy_req, "public", "someId") + s: SessionContainer = await create_new_session( + dummy_req, "public", RecipeUserId("someId") + ) res = await get_claim_value(s.get_handle(), TrueClaim) assert isinstance(res, GetClaimValueOkResult) diff --git a/tests/sessions/claims/test_primitive_array_claim.py b/tests/sessions/claims/test_primitive_array_claim.py index ac1469b56..ef2b8ef78 100644 --- a/tests/sessions/claims/test_primitive_array_claim.py +++ b/tests/sessions/claims/test_primitive_array_claim.py @@ -1,10 +1,11 @@ import math -from typing import List, Tuple +from typing import Any, Dict, List, Tuple from unittest.mock import MagicMock from pytest import fixture, mark from pytest_mock import MockerFixture from supertokens_python.recipe.session.claims import PrimitiveArrayClaim +from supertokens_python.types import RecipeUserId from supertokens_python.utils import get_timestamp_ms, resolve from tests.utils import AsyncMock from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID @@ -57,31 +58,39 @@ def patch_get_timestamp_ms(pac_time_patch: Tuple[MockerFixture, int]): async def test_primitive_claim(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveArrayClaim("key", async_fetch_value) - ctx = {} - res = await claim.build("user_id", "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", "public", ctx) + ctx: Dict[str, Any] = {} + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, ctx) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - user_id, ctx = "user_id", {} - await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + ctx: Dict[str, Any] = {} + recipe_user_id = RecipeUserId(user_id) + await claim.build(user_id, recipe_user_id, DEFAULT_TENANT_ID, {}, ctx) assert sync_fetch_value.call_count == 1 - assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ + assert ( + user_id, + recipe_user_id, + DEFAULT_TENANT_ID, + ctx, + {}, + ) == sync_fetch_value.call_args_list[0][ 0 ] # extra [0] refers to call params @@ -91,8 +100,8 @@ async def test_primitive_claim_fetch_value_none(): fetch_value_none.return_value = None claim = PrimitiveArrayClaim("key", fetch_value_none) - user_id, ctx = "user_id", {} - res = await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert res == {} @@ -120,7 +129,9 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) assert claim.get_last_refetch_time(payload) == timestamp @@ -142,7 +153,9 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.includes(excluded_item).validate(payload, {}) assert res.is_valid is False @@ -155,7 +168,9 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.includes(included_item).validate(payload, {}) assert res.is_valid is True @@ -163,7 +178,9 @@ async def test_validator_should_validate_matching_payload(): async def test_should_not_validate_old_values(patch_get_timestamp_ms: MagicMock): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -181,7 +198,9 @@ async def test_should_validate_old_values_if_max_age_is_none_and_default_is_inf( patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -199,7 +218,9 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) assert ( await resolve( claim.validators.includes(excluded_item, 600).should_refetch(payload, {}) @@ -212,7 +233,9 @@ async def test_validator_should_refetch_if_value_is_old( patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -229,7 +252,9 @@ async def test_validator_should_not_refetch_if_max_age_is_none_and_default_is_in patch_get_timestamp_ms: MagicMock, ): claim = claim_with_inf_max_age - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -246,7 +271,9 @@ async def test_validator_should_validate_values_with_default_max_age( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 MINS: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -259,7 +286,9 @@ async def test_validator_should_not_refetch_if_max_age_overrides_to_inf( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 1 week patch_get_timestamp_ms.return_value += 7 * 24 * 60 * 60 * SECONDS # type: ignore @@ -292,7 +321,9 @@ async def test_validator_excludes_should_not_validate_empty_payload(): async def test_validator_excludes_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.excludes(included_item).validate(payload, {}) assert res.is_valid is False @@ -305,7 +336,9 @@ async def test_validator_excludes_should_not_validate_mismatching_payload(): async def test_validator_excludes_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.excludes(excluded_item).validate(payload, {}) assert res.is_valid is True @@ -328,7 +361,9 @@ async def test_validator_includes_all_should_not_validate_empty_payload(): async def test_validator_includes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.includes_all(excluded_item).validate(payload, {}) assert res.is_valid is False @@ -341,7 +376,9 @@ async def test_validator_includes_all_should_not_validate_mismatching_payload(): async def test_validator_includes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.includes_all(included_item).validate(payload, {}) assert res.is_valid is True @@ -364,7 +401,9 @@ async def test_validator_excludes_all_should_not_validate_empty_payload(): async def test_validator_excludes_all_should_not_validate_mismatching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.excludes_all(included_item).validate(payload, {}) assert res.is_valid is False @@ -377,7 +416,9 @@ async def test_validator_excludes_all_should_not_validate_mismatching_payload(): async def test_validator_excludes_all_should_validate_matching_payload(): claim = PrimitiveArrayClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.excludes_all(excluded_item).validate(payload, {}) assert res.is_valid is True @@ -387,10 +428,12 @@ async def test_validator_should_not_validate_older_values_with_5min_default_max_ patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveArrayClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 MINS: - patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore + patch_get_timestamp_ms.return_value += 10 * MINS res = await resolve(claim.validators.includes(included_item).validate(payload, {})) assert res.is_valid is False diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 2ee8d98ba..bbfa2d815 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -2,6 +2,7 @@ from pytest import mark from supertokens_python.recipe.session.claims import PrimitiveClaim +from supertokens_python.types import RecipeUserId from supertokens_python.utils import resolve from tests.utils import AsyncMock @@ -24,31 +25,34 @@ def teardown_function(_): async def test_primitive_claim(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_without_async_fetch_value(timestamp: int): claim = PrimitiveClaim("key", async_fetch_value) - ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == {"key": {"t": timestamp, "v": val}} async def test_primitive_claim_matching__add_to_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - ctx = {} - res = await claim.build("user_id", "public", ctx) + res = await claim.build("user_id", RecipeUserId("user_id"), "public", {}, {}) assert res == claim.add_to_payload_({}, val, {}) async def test_primitive_claim_fetch_value_params_correct(): claim = PrimitiveClaim("key", sync_fetch_value) - user_id, ctx = "user_id", {} - await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert sync_fetch_value.call_count == 1 - assert (user_id, DEFAULT_TENANT_ID, ctx) == sync_fetch_value.call_args_list[0][ + assert ( + user_id, + RecipeUserId(user_id), + DEFAULT_TENANT_ID, + {}, + {}, + ) == sync_fetch_value.call_args_list[0][ 0 ] # extra [0] refers to call params @@ -58,8 +62,8 @@ async def test_primitive_claim_fetch_value_none(): fetch_value_none.return_value = None claim = PrimitiveClaim("key", fetch_value_none) - user_id, ctx = "user_id", {} - res = await claim.build(user_id, DEFAULT_TENANT_ID, ctx) + user_id = "user_id" + res = await claim.build(user_id, RecipeUserId(user_id), DEFAULT_TENANT_ID, {}, {}) assert res == {} @@ -89,7 +93,9 @@ async def test_get_last_refetch_time_empty_payload(): async def test_should_return_none_for_empty_payload(timestamp: int): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) assert claim.get_last_refetch_time(payload) == timestamp @@ -111,7 +117,9 @@ async def test_validators_should_not_validate_empty_payload(): async def test_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.has_value(val2).validate(payload, {}) assert res.is_valid is False @@ -124,7 +132,9 @@ async def test_should_not_validate_mismatching_payload(): async def test_validator_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.has_value(val).validate(payload, {}) assert res.is_valid is True @@ -132,7 +142,9 @@ async def test_validator_should_validate_matching_payload(): async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -150,7 +162,9 @@ async def test_should_refetch_if_value_not_set(): async def test_validator_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) assert ( await resolve(claim.validators.has_value(val2).should_refetch(payload, {})) is False @@ -174,7 +188,9 @@ async def test_should_not_validate_empty_payload(): async def test_has_fresh_value_should_not_validate_mismatching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.has_value(val2, 600).validate(payload, {}) assert res.is_valid is False assert res.reason == { @@ -186,7 +202,9 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload(): async def test_should_validate_matching_payload(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) res = await claim.validators.has_value(val, 600).validate(payload, {}) assert res.is_valid is True @@ -196,7 +214,9 @@ async def test_should_not_validate_old_values_as_well( ): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -213,14 +233,16 @@ async def test_should_refetch_if_value_is_not_set(): async def test_should_not_refetch_if_value_is_set(): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", "public") + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}, {}) assert claim.validators.has_value(val2, 600).should_refetch(payload, {}) is False async def test_should_refetch_if_value_is_old(patch_get_timestamp_ms: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - payload = await claim.build("userId", DEFAULT_TENANT_ID) + payload = await claim.build( + "userId", RecipeUserId("userId"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -232,7 +254,9 @@ async def test_should_not_validate_old_values_as_well_with_default_max_age_provi patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("user_id", DEFAULT_TENANT_ID) + payload = await claim.build( + "user_id", RecipeUserId("user_id"), DEFAULT_TENANT_ID, {}, {} + ) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore @@ -250,7 +274,7 @@ async def test_should_refetch_if_value_is_old_with_default_max_age_provided( patch_get_timestamp_ms: MagicMock, ): claim = PrimitiveClaim("key", sync_fetch_value, 300) # 5 mins - payload = await claim.build("userId", "public") + payload = await claim.build("userId", RecipeUserId("userId"), "public", {}, {}) # Increase clock time by 10 mins: patch_get_timestamp_ms.return_value += 10 * MINS # type: ignore diff --git a/tests/sessions/claims/test_remove_claim.py b/tests/sessions/claims/test_remove_claim.py index b448af1ae..a954cacbd 100644 --- a/tests/sessions/claims/test_remove_claim.py +++ b/tests/sessions/claims/test_remove_claim.py @@ -11,6 +11,7 @@ remove_claim, ) from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import TrueClaim, get_st_init_args from tests.utils import AsyncMock, setup_function, start_st, teardown_function @@ -37,6 +38,7 @@ async def test_should_attempt_to_set_claim_to_none(): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -56,7 +58,9 @@ async def test_should_clear_previously_set_claim(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s: SessionContainer = await create_new_session(dummy_req, "public", "someId") + s: SessionContainer = await create_new_session( + dummy_req, "public", RecipeUserId("someId") + ) payload = s.get_access_token_payload() @@ -68,7 +72,9 @@ async def test_should_clear_previously_set_claim_using_handle(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s: SessionContainer = await create_new_session(dummy_req, "public", "someId") + s: SessionContainer = await create_new_session( + dummy_req, "public", RecipeUserId("someId") + ) payload = s.get_access_token_payload() assert payload["st-true"] == {"v": True, "t": timestamp} diff --git a/tests/sessions/claims/test_set_claim_value.py b/tests/sessions/claims/test_set_claim_value.py index 170cd66a4..a10677692 100644 --- a/tests/sessions/claims/test_set_claim_value.py +++ b/tests/sessions/claims/test_set_claim_value.py @@ -10,6 +10,7 @@ set_claim_value, ) from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import TrueClaim, get_st_init_args from tests.utils import AsyncMock, setup_function, start_st, teardown_function @@ -38,6 +39,7 @@ async def test_should_merge_the_right_value(timestamp: int): None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -57,17 +59,17 @@ async def test_should_overwrite_claim_value(timestamp: int): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"t": timestamp, "v": True} await s.set_claim_value(TrueClaim, False) # Payload should be updated now: payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"t": timestamp, "v": False} @@ -76,10 +78,10 @@ async def test_should_overwrite_claim_value_using_session_handle(timestamp: int) start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) payload = s.get_access_token_payload() - assert len(payload) == 10 + assert len(payload) == 11 assert payload["st-true"] == {"t": timestamp, "v": True} await set_claim_value(s.get_handle(), TrueClaim, False) diff --git a/tests/sessions/claims/test_validate_claims_for_session_handle.py b/tests/sessions/claims/test_validate_claims_for_session_handle.py index 72153ecee..ccad5b464 100644 --- a/tests/sessions/claims/test_validate_claims_for_session_handle.py +++ b/tests/sessions/claims/test_validate_claims_for_session_handle.py @@ -13,6 +13,7 @@ ClaimsValidationResult, SessionDoesNotExistError, ) +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import ( get_st_init_args, NoneClaim, @@ -31,7 +32,7 @@ async def test_should_return_the_right_validation_errors(): start_st() dummy_req: BaseRequest = MagicMock() - s = await create_new_session(dummy_req, "public", "someId") + s = await create_new_session(dummy_req, "public", RecipeUserId("someId")) failing_validator = NoneClaim.validators.has_value(True) res = await validate_claims_for_session_handle( @@ -40,7 +41,7 @@ async def test_should_return_the_right_validation_errors(): ) assert isinstance(res, ClaimsValidationResult) and len(res.invalid_claims) == 1 - assert res.invalid_claims[0].id == failing_validator.id + assert res.invalid_claims[0].id_ == failing_validator.id assert res.invalid_claims[0].reason == { "message": "value does not exist", "actualValue": None, diff --git a/tests/sessions/claims/test_verify_session.py b/tests/sessions/claims/test_verify_session.py index ae373e064..ff5ac669d 100644 --- a/tests/sessions/claims/test_verify_session.py +++ b/tests/sessions/claims/test_verify_session.py @@ -28,6 +28,7 @@ ClaimsValidationResult, ) from supertokens_python.recipe.session.session_class import Session +from supertokens_python.types import RecipeUserId from tests.sessions.claims.utils import TrueClaim, NoneClaim from tests.utils import ( setup_function, @@ -50,8 +51,9 @@ def st_init_generator_with_overriden_global_validators( ): def session_function_override(oi: RecipeInterface) -> RecipeInterface: async def new_get_global_claim_validators( - _user_id: str, _tenant_id: str, + _user_id: str, + _recipe_user_id: RecipeUserId, _claim_validators_added_by_other_recipes: List[SessionClaimValidator], _user_context: Dict[str, Any], ): @@ -77,8 +79,9 @@ async def new_get_global_claim_validators( def st_init_generator_with_claim_validator(claim_validator: SessionClaimValidator): def session_function_override(oi: RecipeInterface) -> RecipeInterface: async def new_get_global_claim_validators( - _user_id: str, _tenant_id: str, + _user_id: str, + _recipe_user_id: RecipeUserId, claim_validators_added_by_other_recipes: List[SessionClaimValidator], _user_context: Dict[str, Any], ): @@ -138,13 +141,13 @@ async def fastapi_client(): @app.post("/login") async def _login(request: Request): # type: ignore user_id = "userId" - await create_new_session(request, "public", user_id, {}, {}) + await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"userId": user_id} @app.post("/create-with-claim") async def _create_with_claim(request: Request): # type: ignore user_id = "userId" - _ = await create_new_session(request, "public", user_id, {}, {}) + _ = await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) key: str = (await request.json())["key"] # PrimitiveClaim(key, fetch_value="Value").add_to_session(session, "value") return {"userId": key} @@ -377,6 +380,7 @@ async def test_should_reject_if_assert_claims_returns_an_error( None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -426,6 +430,7 @@ async def test_should_allow_if_assert_claims_returns_no_error( None, # anti csrf token "test_session_handle", "test_user_id", + RecipeUserId("test_user_id"), {}, # user_data_in_access_token None, # req_res_info False, # access_token_updated @@ -442,7 +447,7 @@ async def test_should_allow_if_assert_claims_returns_no_error( assert validators == [validator] assert ctx["_default"]["request"] recipe_impl_mock.validate_claims.assert_called_once_with( # type: ignore - "test_user_id", {}, [validator], ctx + "test_user_id", RecipeUserId("test_user_id"), {}, [validator], ctx ) diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index fd6ae7d65..b66fd6d68 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -6,10 +6,11 @@ SessionClaim, ) from supertokens_python.recipe.session.interfaces import RecipeInterface +from supertokens_python.types import RecipeUserId from tests.utils import st_init_common_args -TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___: True) # type: ignore -NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___: None) # type: ignore +TrueClaim = BooleanClaim("st-true", fetch_value=lambda _, __, ___, _____, ______: True) +NoneClaim = BooleanClaim("st-none", fetch_value=lambda _, __, ___, _____, ______: None) def session_functions_override_with_claim( @@ -23,15 +24,22 @@ def session_function_override(oi: RecipeInterface) -> RecipeInterface: async def new_create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Union[None, Dict[str, Any]], session_data_in_database: Union[None, Dict[str, Any]], disable_anti_csrf: Optional[bool], tenant_id: str, user_context: Dict[str, Any], ): - payload_update = await claim.build(user_id, tenant_id, user_context) if access_token_payload is None: access_token_payload = {} + payload_update = await claim.build( + user_id, + RecipeUserId(user_id), + tenant_id, + access_token_payload, + user_context, + ) access_token_payload = { **access_token_payload, **payload_update, @@ -40,6 +48,7 @@ async def new_create_new_session( return await oi_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, diff --git a/tests/sessions/test_access_token_version.py b/tests/sessions/test_access_token_version.py index 55e0448be..8b532590a 100644 --- a/tests/sessions/test_access_token_version.py +++ b/tests/sessions/test_access_token_version.py @@ -17,6 +17,7 @@ validate_access_token_structure, ) from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.types import RecipeUserId from tests.utils import get_st_init_args, setup_function, start_st, teardown_function _ = setup_function # type:ignore @@ -30,7 +31,9 @@ async def test_access_token_v4(): start_st() access_token = ( - await create_new_session_without_request_response("public", "user-id") + await create_new_session_without_request_response( + "public", RecipeUserId("user-id") + ) ).get_access_token() s = await get_session_without_request_response(access_token) assert s is not None @@ -44,7 +47,7 @@ async def test_access_token_v4(): False, ) assert res["userId"] == "user-id" - assert parsed_info.version == 4 + assert parsed_info.version == 5 async def test_parsing_access_token_v2(): @@ -79,7 +82,9 @@ async def _create(request: Request): # type: ignore except Exception: pass - session = await create_new_session(request, "public", "userId", body, {}) + session = await create_new_session( + request, "public", RecipeUserId("userId"), body, {} + ) return {"message": True, "sessionHandle": session.get_handle()} @fast.get("/merge-into-payload") @@ -212,7 +217,7 @@ async def test_ignore_protected_props_in_create_session(): s = await create_new_session_without_request_response( "public", - "user1", + RecipeUserId("user1"), {"foo": "bar"}, ) payload = parse_jwt_without_signature_verification(s.access_token).payload @@ -220,7 +225,7 @@ async def test_ignore_protected_props_in_create_session(): assert payload["sub"] == "user1" s2 = await create_new_session_without_request_response( - "public", "user2", s.get_access_token_payload() + "public", RecipeUserId("user2"), s.get_access_token_payload() ) payload = parse_jwt_without_signature_verification(s2.access_token).payload assert payload["foo"] == "bar" diff --git a/tests/sessions/test_auth_mode.py b/tests/sessions/test_auth_mode.py index 0a01e0a0b..69b0b1b86 100644 --- a/tests/sessions/test_auth_mode.py +++ b/tests/sessions/test_auth_mode.py @@ -2,6 +2,7 @@ from typing_extensions import Literal from fastapi import Depends, FastAPI, Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark from supertokens_python import init @@ -33,7 +34,9 @@ async def app(): @fast.post("/create") async def _create(request: Request): # type: ignore body = await request.json() - session = await create_new_session(request, "public", "userId", body, {}) + session = await create_new_session( + request, "public", RecipeUserId("userId"), body, {} + ) return {"message": True, "sessionHandle": session.get_handle()} @fast.get("/update-payload") diff --git a/tests/sessions/test_jwks.py b/tests/sessions/test_jwks.py index 7affdbf31..7b800210d 100644 --- a/tests/sessions/test_jwks.py +++ b/tests/sessions/test_jwks.py @@ -15,6 +15,7 @@ get_session_without_request_response, ) from supertokens_python.recipe.session.recipe import SessionRecipe +from supertokens_python.types import RecipeUserId from supertokens_python.utils import get_timestamp_ms from tests.utils import ( get_st_init_args, @@ -85,7 +86,9 @@ async def test_that_jwks_is_fetched_as_expected(caplog: LogCaptureFixture): assert next(well_known_count) == 0 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) time.sleep(jwk_max_age_sec) tokens = s.get_all_session_tokens_dangerously() @@ -173,7 +176,9 @@ async def test_that_jwks_are_refresh_if_kid_is_unknown(caplog: LogCaptureFixture assert next(well_known_count) == 0 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert next(well_known_count) == 0 @@ -188,7 +193,9 @@ async def test_that_jwks_are_refresh_if_kid_is_unknown(caplog: LogCaptureFixture assert next(well_known_count) == 1 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert next(well_known_count) == 1 @@ -258,7 +265,9 @@ async def test_jwks_cache_logic(caplog: LogCaptureFixture): assert next(jwks_refresh_count) == 0 - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert get_cached_keys() is None assert next(jwks_refresh_count) == 0 @@ -383,7 +392,9 @@ async def test_that_jwks_returns_from_cache_correctly(caplog: LogCaptureFixture) init(**get_st_init_args(recipe_list=[session.init(jwks_refresh_interval_sec=2)])) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) assert get_cached_keys() is None assert next(jwk_refresh_count) == 0 assert next(returned_from_cache_count) == 0 @@ -478,7 +489,9 @@ async def test_session_verification_of_jwt_based_on_session_payload( init(**get_st_init_args(recipe_list=[session.init()])) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) payload = s.get_access_token_payload() del payload["iat"] @@ -500,7 +513,9 @@ async def test_session_verification_of_jwt_based_on_session_payload_with_check_d init(**get_st_init_args(recipe_list=[session.init()])) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) payload = s.get_access_token_payload() del payload["iat"] @@ -525,7 +540,9 @@ async def test_session_verification_of_jwt_with_dynamic_signing_key(): ) start_st() - s = await create_new_session_without_request_response("public", "userId", {}, {}) + s = await create_new_session_without_request_response( + "public", RecipeUserId("userId"), {}, {} + ) payload = s.get_access_token_payload() del payload["iat"] @@ -645,7 +662,7 @@ async def client(): @app.get("/login") async def login(request: Request): # type: ignore user_id = "test" - s = await create_new_session(request, "public", user_id, {}, {}) + s = await create_new_session(request, "public", RecipeUserId(user_id), {}, {}) return {"jwt": s.get_access_token()} @app.get("/sessioninfo") diff --git a/tests/sessions/test_session_error_handlers.py b/tests/sessions/test_session_error_handlers.py index 2dc052a97..0cb569297 100644 --- a/tests/sessions/test_session_error_handlers.py +++ b/tests/sessions/test_session_error_handlers.py @@ -17,6 +17,7 @@ from fastapi import FastAPI from fastapi.requests import Request from pytest import fixture, mark +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import init @@ -66,7 +67,7 @@ async def test_try_refresh(_request: Request): # type: ignore @app.post("/test/token-theft") async def test_token_theft(_request: Request): # type: ignore - raise TokenTheftError("", "") + raise TokenTheftError("", RecipeUserId(""), "") @app.post("/test/claim-validation") async def test_claim_validation(_request: Request): # type: ignore @@ -88,7 +89,11 @@ def unauthorised_f(_req: BaseRequest, _message: str, res: BaseResponse): return res def token_theft_f( - _req: BaseRequest, _session_handle: str, _user_id: str, res: BaseResponse + _req: BaseRequest, + _session_handle: str, + _user_id: str, + _rid: RecipeUserId, + res: BaseResponse, ): res.set_status_code(403) res.set_json_content({"message": "token theft detected from errorHandler"}) diff --git a/tests/sessions/test_use_dynamic_signing_key_switching.py b/tests/sessions/test_use_dynamic_signing_key_switching.py index 970a28df7..9dc3fc84b 100644 --- a/tests/sessions/test_use_dynamic_signing_key_switching.py +++ b/tests/sessions/test_use_dynamic_signing_key_switching.py @@ -7,6 +7,7 @@ refresh_session_without_request_response, ) from supertokens_python.recipe.session.session_class import SessionContainer +from supertokens_python.types import RecipeUserId from tests.utils import ( get_st_init_args, setup_function, @@ -31,7 +32,7 @@ async def test_dynamic_key_switching(): # Create a new session without an actual HTTP request-response flow create_res: SessionContainer = await create_new_session_without_request_response( - "public", "test-user-id", {"tokenProp": True}, {"dbProp": True} + "public", RecipeUserId("test-user-id"), {"tokenProp": True}, {"dbProp": True} ) # Extract session tokens for further testing @@ -67,7 +68,7 @@ async def test_refresh_session(): # Create a new session without an actual HTTP request-response flow create_res: SessionContainer = await create_new_session_without_request_response( - "public", "test-user-id", {"tokenProp": True}, {"dbProp": True} + "public", RecipeUserId("test-user-id"), {"tokenProp": True}, {"dbProp": True} ) # Extract session tokens for further testing diff --git a/tests/supertokens_python/test_supertokens_functions.py b/tests/supertokens_python/test_supertokens_functions.py index d33299a61..721db4a02 100644 --- a/tests/supertokens_python/test_supertokens_functions.py +++ b/tests/supertokens_python/test_supertokens_functions.py @@ -61,19 +61,19 @@ async def test_supertokens_functions(): for e in emails: signup_resp = await ep_asyncio.sign_up("public", e, "secret_pass") assert isinstance(signup_resp, SignUpOkResult) - user_ids.append(signup_resp.user.user_id) + user_ids.append(signup_resp.user.id) # Get user count assert await st_asyncio.get_user_count() == len(emails) # Get users in ascending order by joining time users_asc = (await st_asyncio.get_users_oldest_first("public", limit=10)).users - emails_asc = [user.email for user in users_asc] + emails_asc = [user.emails[0] for user in users_asc] assert emails_asc == emails # Get users in descending order by joining time users_desc = (await st_asyncio.get_users_newest_first("public", limit=10)).users - emails_desc = [user.email for user in users_desc] + emails_desc = [user.emails[0] for user in users_desc] assert emails_desc == emails[::-1] version = await Querier.get_instance().get_api_version() @@ -87,7 +87,7 @@ async def test_supertokens_functions(): # Again, get users in ascending order by joining time # We expect that the 2nd user (bar@example.com) must be absent. users_asc = (await st_asyncio.get_users_oldest_first("public", limit=10)).users - emails_asc = [user.email for user in users_asc] + emails_asc = [user.emails[0] for user in users_asc] assert emails[1] not in emails_asc # The 2nd user must be deleted now. if not is_version_gte(version, "2.20"): @@ -103,8 +103,8 @@ async def test_supertokens_functions(): "public", limit=10, query={"email": "baz"} ) ).users - emails_asc = [user.email for user in users_asc] - emails_desc = [user.email for user in users_desc] + emails_asc = [user.emails[0] for user in users_asc] + emails_desc = [user.emails[0] for user in users_desc] assert len(emails_asc) == 1 assert len(emails_desc) == 1 @@ -118,7 +118,7 @@ async def test_supertokens_functions(): "public", limit=10, query={"email": "john"} ) ).users - emails_asc = [user.email for user in users_asc] - emails_desc = [user.email for user in users_desc] + emails_asc = [user.emails[0] for user in users_asc] + emails_desc = [user.emails[0] for user in users_desc] assert len(emails_asc) == 0 assert len(emails_desc) == 0 diff --git a/tests/test-server/__init__.py b/tests/test-server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test-server/accountlinking.py b/tests/test-server/accountlinking.py new file mode 100644 index 000000000..30b73b694 --- /dev/null +++ b/tests/test-server/accountlinking.py @@ -0,0 +1,274 @@ +from flask import Flask, request, jsonify +from supertokens_python import async_to_sync_wrapper, convert_to_recipe_user_id +from supertokens_python.recipe.accountlinking.syncio import can_create_primary_user +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.accountlinking.syncio import is_sign_in_allowed +from supertokens_python.recipe.accountlinking.syncio import is_sign_up_allowed +from supertokens_python.recipe.accountlinking.syncio import ( + get_primary_user_that_can_be_linked_to_recipe_user_id, +) +from supertokens_python.recipe.accountlinking.syncio import ( + create_primary_user_id_or_link_accounts, +) +from supertokens_python.recipe.accountlinking.syncio import unlink_account +from supertokens_python.recipe.accountlinking.syncio import is_email_change_allowed +from supertokens_python.recipe.accountlinking.syncio import ( + link_accounts, + create_primary_user, +) +from supertokens_python.recipe.accountlinking.interfaces import ( + CanCreatePrimaryUserOkResult, + CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError, + CreatePrimaryUserOkResult, + CreatePrimaryUserRecipeUserIdAlreadyLinkedError, + LinkAccountsAccountInfoAlreadyAssociatedError, + LinkAccountsOkResult, + LinkAccountsRecipeUserIdAlreadyLinkedError, +) +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import User +from utils import serialize_user # pylint: disable=import-error +from session import convert_session_to_container + + +def add_accountlinking_routes(app: Flask): + @app.route("/test/accountlinking/createprimaryuser", methods=["POST"]) # type: ignore + def create_primary_user_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = create_primary_user(recipe_user_id, request.json.get("userContext")) + if isinstance(response, CreatePrimaryUserOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, + } + ) + elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + elif isinstance(response, CreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + + @app.route("/test/accountlinking/linkaccounts", methods=["POST"]) # type: ignore + def link_accounts_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = link_accounts( + recipe_user_id, + request.json["primaryUserId"], + request.json.get("userContext"), + ) + if isinstance(response, LinkAccountsOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + "accountsAlreadyLinked": response.accounts_already_linked, + } + ) + elif isinstance(response, LinkAccountsRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, LinkAccountsAccountInfoAlreadyAssociatedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "status": response.status, + } + ) + + @app.route("/test/accountlinking/isemailchangeallowed", methods=["POST"]) # type: ignore + def is_email_change_allowed_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = is_email_change_allowed( + recipe_user_id, + request.json["newEmail"], + request.json["isVerified"], + session, + request.json.get("userContext"), + ) + return jsonify(response) + + @app.route("/test/accountlinking/unlinkaccount", methods=["POST"]) # type: ignore + def unlink_account_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = unlink_account( + recipe_user_id, + request.json.get("userContext"), + ) + return jsonify( + { + "status": response.status, + "wasRecipeUserDeleted": response.was_recipe_user_deleted, + "wasLinked": response.was_linked, + } + ) + + @app.route("/test/accountlinking/createprimaryuseridorlinkaccounts", methods=["POST"]) # type: ignore + def create_primary_user_id_or_link_accounts_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = create_primary_user_id_or_link_accounts( + request.json["tenantId"], + recipe_user_id, + session, + request.json.get("userContext", None), + ) + return jsonify(response.to_json()) + + @app.route("/test/accountlinking/getprimaryuserthatcanbelinkedtorecipeuserid", methods=["POST"]) # type: ignore + def get_primary_user_that_can_be_linked_to_recipe_user_id_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = get_primary_user_that_can_be_linked_to_recipe_user_id( + request.json["tenantId"], + recipe_user_id, + request.json.get("userContext", None), + ) + return jsonify(response.to_json() if response else None) + + @app.route("/test/accountlinking/issignupallowed", methods=["POST"]) # type: ignore + def is_signup_allowed_api(): # type: ignore + assert request.json is not None + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = is_sign_up_allowed( + request.json["tenantId"], + AccountInfoWithRecipeId( + recipe_id=request.json["newUser"]["recipeId"], + email=( + request.json["newUser"]["email"] + if "email" in request.json["newUser"] + else None + ), + phone_number=( + request.json["newUser"]["phoneNumber"] + if "phoneNumber" in request.json["newUser"] + else None + ), + third_party=( + ThirdPartyInfo( + third_party_user_id=request.json["newUser"]["thirdParty"]["id"], + third_party_id=request.json["newUser"]["thirdParty"][ + "thirdPartyId" + ], + ) + if "thirdParty" in request.json["newUser"] + else None + ), + ), + request.json["isVerified"], + session, + request.json.get("userContext", None), + ) + return jsonify(response) + + @app.route("/test/accountlinking/issigninallowed", methods=["POST"]) # type: ignore + def is_signin_allowed_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + session = None + if "session" in request.json: + session = convert_session_to_container(request) + response = is_sign_in_allowed( + request.json["tenantId"], + recipe_user_id, + session, + request.json.get("userContext", None), + ) + return jsonify(response) + + @app.route("/test/accountlinking/verifyemailforrecipeuseriflinkedaccountsareverified", methods=["POST"]) # type: ignore + def verify_email_for_recipe_user_if_linked_accounts_are_verified_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + user = User.from_json(request.json["user"]) + async_to_sync_wrapper.sync( + AccountLinkingRecipe.get_instance().verify_email_for_recipe_user_if_linked_accounts_are_verified( + user=user, + recipe_user_id=recipe_user_id, + user_context=request.json.get("userContext"), + ) + ) + return jsonify({}) + + @app.route("/test/accountlinking/cancreateprimaryuser", methods=["POST"]) # type: ignore + def can_create_primary_user_api(): # type: ignore + assert request.json is not None + recipe_user_id = convert_to_recipe_user_id(request.json["recipeUserId"]) + response = can_create_primary_user( + recipe_user_id, request.json.get("userContext") + ) + if isinstance(response, CanCreatePrimaryUserOkResult): + return jsonify( + { + "status": response.status, + "wasAlreadyAPrimaryUser": response.was_already_a_primary_user, + } + ) + elif isinstance(response, CanCreatePrimaryUserRecipeUserIdAlreadyLinkedError): + return jsonify( + { + "description": response.description, + "primaryUserId": response.primary_user_id, + "status": response.status, + } + ) + else: + return jsonify( + { + "description": response.description, + "status": response.status, + "primaryUserId": response.primary_user_id, + } + ) diff --git a/tests/test-server/app.py b/tests/test-server/app.py index f743a56fe..8ee45dd7e 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -1,6 +1,22 @@ +import inspect from typing import Any, Callable, Dict, List, Optional, TypeVar, Tuple from flask import Flask, request, jsonify -from utils import init_test_claims +from supertokens_python import process_state +from supertokens_python.framework import BaseRequest, BaseResponse +from supertokens_python.ingredients.emaildelivery.types import EmailDeliveryConfig +from supertokens_python.ingredients.smsdelivery.types import SMSDeliveryConfig +from supertokens_python.post_init_callbacks import PostSTInitCallbacks +from supertokens_python.recipe import ( + accountlinking, + dashboard, + multifactorauth, + passwordless, + totp, +) +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.totp.recipe import TOTPRecipe +from passwordless import add_passwordless_routes # pylint: disable=import-error from supertokens_python.process_state import ProcessState from supertokens_python.recipe.dashboard.recipe import DashboardRecipe from supertokens_python.recipe.emailpassword.recipe import EmailPasswordRecipe @@ -12,10 +28,21 @@ from supertokens_python.recipe.thirdparty.recipe import ThirdPartyRecipe from supertokens_python.recipe.usermetadata.recipe import UserMetadataRecipe from supertokens_python.recipe.userroles.recipe import UserRolesRecipe -from test_functions_mapper import get_func # type: ignore -from emailpassword import add_emailpassword_routes -from multitenancy import add_multitenancy_routes -from session import add_session_routes +from supertokens_python.types import RecipeUserId +from test_functions_mapper import ( # pylint: disable=import-error + get_func, + get_override_params, + reset_override_params, +) # pylint: disable=import-error +from totp import add_totp_routes # pylint: disable=import-error +from emailpassword import add_emailpassword_routes # pylint: disable=import-error +from thirdparty import add_thirdparty_routes # pylint: disable=import-error +from multitenancy import add_multitenancy_routes # pylint: disable=import-error +from accountlinking import add_accountlinking_routes # pylint: disable=import-error +from emailverification import ( + add_emailverification_routes, +) # pylint: disable=import-error +from session import add_session_routes # pylint: disable=import-error from supertokens_python import ( AppInfo, Supertokens, @@ -29,7 +56,7 @@ thirdparty, emailverification, ) -from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session import InputErrorHandlers, SessionContainer from supertokens_python.recipe.session.framework.flask import verify_session from supertokens_python.recipe.thirdparty.provider import UserFields, UserInfoMap from supertokens_python.recipe_module import RecipeModule @@ -46,11 +73,24 @@ def default_st_init(): + def origin_func( # pylint: disable=unused-argument, dangerous-default-value + request: Optional[BaseRequest] = None, + context: Dict[ # pylint: disable=unused-argument, dangerous-default-value + str, Any + ] = {}, # pylint: disable=unused-argument, dangerous-default-value + ) -> str: + if request is None: + return "http://localhost:8080" + origin = request.get_header("origin") + if origin is not None: + return origin + return "http://localhost:8080" + init( app_info=InputAppInfo( app_name="SuperTokens", api_domain="http://api.supertokens.io", - website_domain="http://localhost:3000", + origin=origin_func, ), supertokens_config=SupertokensConfig(connection_uri="http://localhost:3567"), framework="flask", @@ -63,54 +103,74 @@ def default_st_init(): def toCamelCase(snake_case: str) -> str: components = snake_case.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -def create_override(input, member, name): # type: ignore - member_val = getattr(input, member) # type: ignore - - async def override(*args, **kwargs): # type: ignore - override_logging.log_override_event( - name + "." + toCamelCase(member), # type: ignore - "CALL", - {"args": args, "kwargs": kwargs}, - ) + res = components[0] + "".join(x.title() for x in components[1:]) + # Convert 'post', 'get', or 'put' at the end to uppercase + if res.endswith("Post"): + res = res[:-4] + "POST" + if res.endswith("Get"): + res = res[:-3] + "GET" + if res.endswith("Put"): + res = res[:-3] + "PUT" + return res + + +def create_override( + oI: Any, functionName: str, name: str, override_name: Optional[str] = None +): + implementation = oI if override_name is None else get_func(override_name)(oI) + originalFunction = getattr(implementation, functionName) + + async def finalFunction(*args: Any, **kwargs: Any): + if len(args) > 0: + override_logging.log_override_event( + name + "." + toCamelCase(functionName), + "CALL", + args, + ) + else: + override_logging.log_override_event( + name + "." + toCamelCase(functionName), "CALL", kwargs + ) try: - res = await member_val(*args, **kwargs) + if inspect.iscoroutinefunction(originalFunction): + res = await originalFunction(*args, **kwargs) + else: + res = originalFunction(*args, **kwargs) override_logging.log_override_event( - name + "." + toCamelCase(member), "RES", res # type: ignore + name + "." + toCamelCase(functionName), "RES", res ) return res except Exception as e: override_logging.log_override_event( - name + "." + toCamelCase(member), "REJ", e # type: ignore + name + "." + toCamelCase(functionName), "REJ", e ) raise e - setattr(input, member, override) # type: ignore + setattr(oI, functionName, finalFunction) def override_builder_with_logging( name: str, override_name: Optional[str] = None ) -> Callable[[T], T]: - def builder(input: T) -> T: - for member in dir(input): - if member.startswith("__"): - continue - - member_val = getattr(input, member) - if callable(member_val): - create_override(input, member, name) - return input + def builder(oI: T) -> T: + members = [ + attr + for attr in dir(oI) + if callable(getattr(oI, attr)) and not attr.startswith("__") + ] + for member in members: + create_override(oI, member, name, override_name) + return oI return builder def logging_override_func_sync(name: str, c: Any) -> Any: - def inner(*args, **kwargs): # type: ignore - override_logging.log_override_event( - name, "CALL", {"args": args, "kwargs": kwargs} - ) + def inner(*args: Any, **kwargs: Any) -> Any: + if len(args) > 0: + override_logging.log_override_event(name, "CALL", args) + else: + override_logging.log_override_event(name, "CALL", kwargs) try: res = c(*args, **kwargs) override_logging.log_override_event(name, "RES", res) @@ -119,10 +179,33 @@ def inner(*args, **kwargs): # type: ignore override_logging.log_override_event(name, "REJ", e) raise e - return inner # type: ignore + return inner + + +def callback_with_log( + name: str, override_name: Optional[str], default_value: Any = None +) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + if override_name: + impl = get_func(override_name) + else: + + async def default_func( # pylint: disable=unused-argument + *args: Any, **kwargs: Any # pylint: disable=unused-argument + ) -> Any: # pylint: disable=unused-argument + return default_value + + impl = default_func + + return logging_override_func_sync(name, impl)(*args, **kwargs) + + return wrapper def st_reset(): + PostSTInitCallbacks.reset() + override_logging.reset_override_logs() + reset_override_params() ProcessState.get_instance().reset() Supertokens.reset() SessionRecipe.reset() @@ -136,18 +219,23 @@ def st_reset(): DashboardRecipe.reset() PasswordlessRecipe.reset() MultitenancyRecipe.reset() + AccountLinkingRecipe.reset() + TOTPRecipe.reset() + MultiFactorAuthRecipe.reset() -def init_st(config): # type: ignore +def init_st(config: Dict[str, Any]): st_reset() override_logging.reset_override_logs() - recipe_list: List[Callable[[AppInfo], RecipeModule]] = [] - for recipe_config in config.get("recipeList", []): # type: ignore - recipe_id = recipe_config.get("recipeId") # type: ignore + recipe_list: List[Callable[[AppInfo], RecipeModule]] = [ + dashboard.init(api_key="test") + ] + for recipe_config in config.get("recipeList", []): + recipe_id = recipe_config.get("recipeId") if recipe_id == "emailpassword": sign_up_feature_input = None - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) if "signUpFeature" in recipe_config_json: sign_up_feature = recipe_config_json["signUpFeature"] if "formFields" in sign_up_feature: @@ -159,11 +247,41 @@ def init_st(config): # type: ignore ) recipe_list.append( - emailpassword.init(sign_up_feature=sign_up_feature_input) + emailpassword.init( + sign_up_feature=sign_up_feature_input, + email_delivery=EmailDeliveryConfig( + override=override_builder_with_logging( + "EmailPassword.emailDelivery.override", + recipe_config_json.get("emailDelivery", {}).get( + "override", None + ), + ) + ), + override=emailpassword.InputOverrideConfig( + apis=override_builder_with_logging( + "EmailPassword.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + functions=override_builder_with_logging( + "EmailPassword.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), + ) ) elif recipe_id == "session": - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + + async def custom_unauthorised_callback( + _: BaseRequest, __: str, response: BaseResponse + ) -> BaseResponse: + response.set_status_code(401) + response.set_json_content(content={"type": "UNAUTHORISED"}) + return response + + recipe_config_json = json.loads(recipe_config.get("config", "{}")) recipe_list.append( session.init( cookie_secure=recipe_config_json.get("cookieSecure"), @@ -183,11 +301,48 @@ def init_st(config): # type: ignore use_dynamic_access_token_signing_key=recipe_config_json.get( "useDynamicAccessTokenSigningKey" ), + override=session.InputOverrideConfig( + apis=override_builder_with_logging( + "Session.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + functions=override_builder_with_logging( + "Session.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), + error_handlers=InputErrorHandlers( + on_unauthorised=custom_unauthorised_callback + ), + ) + ) + elif recipe_id == "accountlinking": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + accountlinking.init( + should_do_automatic_account_linking=callback_with_log( + "AccountLinking.shouldDoAutomaticAccountLinking", + recipe_config_json.get("shouldDoAutomaticAccountLinking"), + accountlinking.ShouldNotAutomaticallyLink(), + ), + on_account_linked=callback_with_log( + "AccountLinking.onAccountLinked", + recipe_config_json.get("onAccountLinked"), + ), + override=accountlinking.InputOverrideConfig( + functions=override_builder_with_logging( + "AccountLinking.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), ) ) - elif recipe_id == "thirdparty": - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) providers: List[thirdparty.ProviderInput] = [] if "signInAndUpFeature" in recipe_config_json: sign_in_up_feature = recipe_config_json["signInAndUpFeature"] @@ -217,7 +372,7 @@ def init_st(config): # type: ignore ), ) - include_in_non_public_tenants_by_default = None + include_in_non_public_tenants_by_default = False if "includeInNonPublicTenantsByDefault" in provider: include_in_non_public_tenants_by_default = provider[ @@ -268,32 +423,180 @@ def init_st(config): # type: ignore ), ), include_in_non_public_tenants_by_default=include_in_non_public_tenants_by_default, + override=override_builder_with_logging( + "ThirdParty.providers.override", + provider.get("override", None), + ), ) providers.append(provider_input) recipe_list.append( thirdparty.init( sign_in_and_up_feature=thirdparty.SignInAndUpFeature( providers=providers - ) + ), + override=thirdparty.InputOverrideConfig( + functions=override_builder_with_logging( + "ThirdParty.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + apis=override_builder_with_logging( + "ThirdParty.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + ), ) ) elif recipe_id == "emailverification": - recipe_config_json = json.loads(recipe_config.get("config", "{}")) # type: ignore - ev_config: Dict[str, Any] = {"mode": "OPTIONAL"} - if "mode" in recipe_config_json: - ev_config["mode"] = recipe_config_json["mode"] + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + + from supertokens_python.recipe.emailverification.utils import ( + OverrideConfig as EmailVerificationOverrideConfig, + ) + from supertokens_python.recipe.emailverification.interfaces import ( + UnknownUserIdError, + ) + + recipe_list.append( + emailverification.init( + mode=( + recipe_config_json["mode"] + if "mode" in recipe_config_json + else "OPTIONAL" + ), + override=EmailVerificationOverrideConfig( + apis=override_builder_with_logging( + "EmailVerification.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + functions=override_builder_with_logging( + "EmailVerification.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + ), + get_email_for_recipe_user_id=callback_with_log( + "EmailVerification.getEmailForRecipeUserId", + recipe_config_json.get("getEmailForRecipeUserId"), + UnknownUserIdError(), + ), + email_delivery=EmailDeliveryConfig( + override=override_builder_with_logging( + "EmailVerification.emailDelivery.override", + recipe_config_json.get("emailDelivery", {}).get( + "override", None + ), + ) + ), + ) + ) + elif recipe_id == "multifactorauth": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + multifactorauth.init( + first_factors=recipe_config_json.get("firstFactors", None), + override=multifactorauth.OverrideConfig( + functions=override_builder_with_logging( + "MultifactorAuth.override.functions", + recipe_config_json.get("override", {}).get( + "functions", None + ), + ), + apis=override_builder_with_logging( + "MultifactorAuth.override.apis", + recipe_config_json.get("override", {}).get("apis", None), + ), + ), + ) + ) + elif recipe_id == "passwordless": + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + contact_config: passwordless.ContactConfig = ( + passwordless.ContactEmailOnlyConfig() + ) + if recipe_config_json.get("contactMethod") == "PHONE": + contact_config = passwordless.ContactPhoneOnlyConfig() + elif recipe_config_json.get("contactMethod") == "EMAIL_OR_PHONE": + contact_config = passwordless.ContactEmailOrPhoneConfig() + + class EmailDeliveryCustom(passwordless.EmailDeliveryInterface[Any]): + async def send_email( + self, template_vars: Any, user_context: Dict[str, Any] + ) -> None: + f = get_func("passwordless.init.emailDelivery.service.sendEmail") + return f(template_vars, user_context) + + class SMSDeliveryCustom(passwordless.SMSDeliveryInterface[Any]): + async def send_sms( + self, template_vars: Any, user_context: Dict[str, Any] + ) -> None: + f = get_func("passwordless.init.smsDelivery.service.sendSms") + return f(template_vars, user_context) - override_functions = override_builder_with_logging("EmailVerification.override.functions") # type: ignore + recipe_list.append( + passwordless.init( + email_delivery=EmailDeliveryConfig( + service=EmailDeliveryCustom(), + override=override_builder_with_logging( + "Passwordless.emailDelivery.override", + config.get("emailDelivery", {}).get("override", None), + ), + ), + sms_delivery=SMSDeliveryConfig( + service=SMSDeliveryCustom(), + override=override_builder_with_logging( + "Passwordless.smsDelivery.override", + config.get("smsDelivery", {}).get("override", None), + ), + ), + contact_config=contact_config, + flow_type=recipe_config_json.get("flowType"), + override=passwordless.InputOverrideConfig( + apis=override_builder_with_logging( + "Passwordless.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "Passwordless.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ), + ) + ) + elif recipe_id == "totp": + from supertokens_python.recipe.totp.types import ( + OverrideConfig as TOTPOverrideConfig, + ) - ev_config["override"] = emailverification.InputOverrideConfig( - functions=override_functions # type: ignore + recipe_config_json = json.loads(recipe_config.get("config", "{}")) + recipe_list.append( + totp.init( + config=totp.TOTPConfig( + default_period=recipe_config_json.get("defaultPeriod"), + default_skew=recipe_config_json.get("defaultSkew"), + issuer=recipe_config_json.get("issuer"), + override=TOTPOverrideConfig( + apis=override_builder_with_logging( + "Multitenancy.override.apis", + recipe_config_json.get("override", {}).get("apis"), + ), + functions=override_builder_with_logging( + "Multitenancy.override.functions", + recipe_config_json.get("override", {}).get("functions"), + ), + ), + ) + ) ) - recipe_list.append(emailverification.init(**ev_config)) - interceptor_func = None # type: ignore - if config.get("supertokens", {}).get("networkInterceptor") is not None: # type: ignore - interceptor_func = get_func(config.get("supertokens", {}).get("networkInterceptor")) # type: ignore + interceptor_func = None + if config.get("supertokens", {}).get("networkInterceptor") is not None: + interceptor_func = get_func( + config.get("supertokens", {}).get("networkInterceptor") + ) def network_interceptor_func( url: str, @@ -313,8 +616,26 @@ def inner( body: Optional[Dict[str, Any]], user_context: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: + # print( + # "-------------------------------------------!!!!!!!!!!!!!!!!!!!!!!!!!!!" + # ) + # print(url) + # import traceback + # print("Stack trace:") + # traceback.print_stack() + if interceptor_func is not None: - return interceptor_func(url, method, headers, params, body, user_context) # type: ignore + resp = interceptor_func( + url, method, headers, params, body, user_context + ) + return { + "url": resp[0], + "method": resp[1], + "headers": resp[2], + "params": resp[3], + "body": resp[4], + "user_context": resp[5], + } return { "url": url, "method": method, @@ -324,7 +645,9 @@ def inner( "user_context": user_context, } - res = logging_override_func_sync("networkInterceptor", inner)(url, method, headers, params, body, user_context) # type: ignore + res = logging_override_func_sync("networkInterceptor", inner)( + url, method, headers, params, body, user_context + ) return ( res.get("url"), res.get("method"), @@ -333,19 +656,24 @@ def inner( res.get("body"), ) - init( - app_info=InputAppInfo( - app_name=config["appInfo"]["appName"], # type: ignore - api_domain=config["appInfo"]["apiDomain"], # type: ignore - website_domain=config["appInfo"]["websiteDomain"], # type: ignore - ), - supertokens_config=SupertokensConfig( - connection_uri=config["supertokens"]["connectionURI"], # type: ignore - network_interceptor=network_interceptor_func, - ), - framework="flask", - recipe_list=recipe_list, - ) + try: + init( + app_info=InputAppInfo( + app_name=config["appInfo"]["appName"], + api_domain=config["appInfo"]["apiDomain"], + website_domain=config["appInfo"]["websiteDomain"], + ), + supertokens_config=SupertokensConfig( + connection_uri=config["supertokens"]["connectionURI"], + network_interceptor=network_interceptor_func, + ), + framework="flask", + recipe_list=recipe_list, + ) + except Exception as e: + st_reset() + default_st_init() + raise e # Routes @@ -356,7 +684,9 @@ def ping(): @app.route("/test/init", methods=["POST"]) # type: ignore def init_handler(): - config = request.json.get("config") # type: ignore + if request.json is None: + return jsonify({"error": "No config provided"}), 400 + config = request.json.get("config") if config: init_st(json.loads(config)) return jsonify({"ok": True}) @@ -365,17 +695,18 @@ def init_handler(): @app.route("/test/overrideparams", methods=["GET"]) # type: ignore def override_params(): - return jsonify("TODO") + return jsonify(get_override_params().to_json()) @app.route("/test/featureflag", methods=["GET"]) # type: ignore def feature_flag(): - return jsonify([]) + return jsonify(["removedOverwriteSessionDuringSignInUp"]) @app.route("/test/resetoverrideparams", methods=["POST"]) # type: ignore -def reset_override_params(): +def reset_override_params_api(): override_logging.reset_override_logs() + reset_override_params() return jsonify({"ok": True}) @@ -389,29 +720,45 @@ def mock_external_api(): return jsonify({"ok": True}) -# @app.route("/create", methods=["POST"]) # type: ignore -# def create_session(): -# recipe_user_id = request.json.get("recipeUserId") # type: ignore +@app.route("/create", methods=["POST"]) # type: ignore +def create_session_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + recipe_user_id = RecipeUserId(data.get("recipeUserId")) + + from supertokens_python.recipe.session.syncio import create_new_session -# session = session.create_new_session(request, "public", recipe_user_id) -# return jsonify({"status": "OK"}) + create_new_session(request, "public", recipe_user_id) + return jsonify({"status": "OK"}) @app.route("/getsession", methods=["POST"]) # type: ignore @verify_session() def get_session(): - session: SessionContainer = request.environ["session"] + from supertokens_python.recipe.session.syncio import get_session + + session = get_session(request) + assert session is not None return jsonify( - {"userId": session.get_user_id(), "recipeUserId": session.get_user_id()} + { + "userId": session.get_user_id(), + "recipeUserId": session.get_recipe_user_id().get_as_string(), + } ) -# @app.route("/refreshsession", methods=["POST"]) # type: ignore -# def refresh_session(): -# session: SessionContainer = session.refresh_session(request) -# return jsonify( -# {"userId": session.get_user_id(), "recipeUserId": session.get_user_id()} -# ) +@app.route("/refreshsession", methods=["POST"]) # type: ignore +def refresh_session_api(): # type: ignore + from supertokens_python.recipe.session.syncio import refresh_session + + session: SessionContainer = refresh_session(request) + return jsonify( + { + "userId": session.get_user_id(), + "recipeUserId": session.get_recipe_user_id().get_as_string(), + } + ) @app.route("/verify", methods=["GET"]) # type: ignore @@ -420,16 +767,58 @@ def verify_session_route(): return jsonify({"status": "OK"}) +@app.route("/test/waitforevent", methods=["GET"]) # type: ignore +def wait_for_event_api(): # type: ignore + event = request.args.get("event") + if not event: + raise ValueError("event query param missing") + + event_enum = process_state.PROCESS_STATE(int(event)) + instance = process_state.ProcessState.get_instance() + event_result = instance.wait_for_event(event_enum) + if event_result is None: + return jsonify(None) + else: + return jsonify("Found") + + @app.errorhandler(404) -def not_found(error): # type: ignore +def not_found(error: Any) -> Any: # pylint: disable=unused-argument return jsonify({"error": f"Route not found: {request.method} {request.path}"}), 404 +import traceback +from flask import jsonify + + +@app.errorhandler(Exception) # type: ignore +def handle_exception(e: Exception): + # Print the error and stack trace + print(f"An error occurred: {str(e)}") + traceback.print_exc() + + # Return JSON response with 500 status code + return jsonify({"error": "Internal Server Error", "message": str(e)}), 500 + + add_emailpassword_routes(app) add_multitenancy_routes(app) add_session_routes(app) +add_emailverification_routes(app) +add_thirdparty_routes(app) +add_accountlinking_routes(app) +add_passwordless_routes(app) +add_totp_routes(app) +from supertokens import add_supertokens_routes # pylint: disable=import-error + +add_supertokens_routes(app) +from usermetadata import add_usermetadata_routes + +add_usermetadata_routes(app) + +from multifactorauth import add_multifactorauth_routes -init_test_claims() +add_multifactorauth_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/emailpassword.py b/tests/test-server/emailpassword.py index efde717b5..553751853 100644 --- a/tests/test-server/emailpassword.py +++ b/tests/test-server/emailpassword.py @@ -1,13 +1,20 @@ from flask import Flask, request, jsonify from supertokens_python.recipe.emailpassword.interfaces import ( - CreateResetPasswordLinkOkResult, + EmailAlreadyExistsError, SignInOkResult, SignUpOkResult, - UpdateEmailOrPasswordEmailAlreadyExistsError, + UnknownUserIdError, + UpdateEmailOrPasswordEmailChangeNotAllowedError, UpdateEmailOrPasswordOkResult, - UpdateEmailOrPasswordUnknownUserIdError, + WrongCredentialsError, ) import supertokens_python.recipe.emailpassword.syncio as emailpassword +from session import convert_session_to_container # pylint: disable=import-error +from supertokens_python.types import RecipeUserId +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error def add_emailpassword_routes(app: Flask): @@ -20,23 +27,33 @@ def emailpassword_signup(): # type: ignore email = data["email"] password = data["password"] user_context = data.get("userContext") + session = convert_session_to_container(data) if "session" in data else None - response = emailpassword.sign_up(tenant_id, email, password, user_context) + response = emailpassword.sign_up( + tenant_id, email, password, session, user_context + ) if isinstance(response, SignUpOkResult): return jsonify( { "status": "OK", - "user": { - "id": response.user.user_id, - "email": response.user.email, - "timeJoined": response.user.time_joined, - "tenantIds": response.user.tenant_ids, - }, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), } ) - else: + elif isinstance(response, EmailAlreadyExistsError): return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) @app.route("/test/emailpassword/signin", methods=["POST"]) # type: ignore def emailpassword_signin(): # type: ignore @@ -48,23 +65,33 @@ def emailpassword_signin(): # type: ignore email = data["email"] password = data["password"] user_context = data.get("userContext") + session = convert_session_to_container(data) if "session" in data else None - response = emailpassword.sign_in(tenant_id, email, password, user_context) + response = emailpassword.sign_in( + tenant_id, email, password, session, user_context + ) if isinstance(response, SignInOkResult): return jsonify( { "status": "OK", - "user": { - "id": response.user.user_id, - "email": response.user.email, - "timeJoined": response.user.time_joined, - "tenantIds": response.user.tenant_ids, - }, + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), } ) - else: + elif isinstance(response, WrongCredentialsError): return jsonify({"status": "WRONG_CREDENTIALS_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) @app.route("/test/emailpassword/createresetpasswordlink", methods=["POST"]) # type: ignore def emailpassword_create_reset_password_link(): # type: ignore @@ -80,8 +107,8 @@ def emailpassword_create_reset_password_link(): # type: ignore tenant_id, user_id, user_context ) - if isinstance(response, CreateResetPasswordLinkOkResult): - return jsonify({"status": "OK", "link": response.link}) + if isinstance(response, str): + return jsonify({"status": "OK", "link": response}) else: return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) @@ -91,7 +118,7 @@ def emailpassword_update_email_or_password(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - user_id = data["userId"] + recipe_user_id = RecipeUserId(data["recipeUserId"]) email = data.get("email") password = data.get("password") apply_password_policy = data.get("applyPasswordPolicy") @@ -99,7 +126,7 @@ def emailpassword_update_email_or_password(): # type: ignore user_context = data.get("userContext") response = emailpassword.update_email_or_password( - user_id, + recipe_user_id, email, password, apply_password_policy, @@ -109,10 +136,14 @@ def emailpassword_update_email_or_password(): # type: ignore if isinstance(response, UpdateEmailOrPasswordOkResult): return jsonify({"status": "OK"}) - elif isinstance(response, UpdateEmailOrPasswordUnknownUserIdError): + elif isinstance(response, UnknownUserIdError): return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) - elif isinstance(response, UpdateEmailOrPasswordEmailAlreadyExistsError): + elif isinstance(response, EmailAlreadyExistsError): return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, UpdateEmailOrPasswordEmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": response.reason} + ) else: return jsonify( { diff --git a/tests/test-server/emailverification.py b/tests/test-server/emailverification.py new file mode 100644 index 000000000..e47660f18 --- /dev/null +++ b/tests/test-server/emailverification.py @@ -0,0 +1,129 @@ +from flask import Flask, request, jsonify + +from supertokens_python import async_to_sync_wrapper +from supertokens_python.framework.flask.flask_request import FlaskRequest +from supertokens_python.recipe.emailverification.interfaces import ( + CreateEmailVerificationTokenOkResult, + VerifyEmailUsingTokenOkResult, +) +from supertokens_python.recipe.emailverification.syncio import ( + create_email_verification_token, +) + + +def add_emailverification_routes(app: Flask): + @app.route("/test/emailverification/isemailverified", methods=["POST"]) # type: ignore + def is_email_verified_api(): # type: ignore + from supertokens_python import convert_to_recipe_user_id + from supertokens_python.recipe.emailverification.syncio import is_email_verified + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) + email = data.get("email") + user_context = data.get("userContext", {}) + + response = is_email_verified(recipe_user_id, email, user_context) + return jsonify(response) + + @app.route("/test/emailverification/createemailverificationtoken", methods=["POST"]) # type: ignore + def f(): # type: ignore + from supertokens_python import convert_to_recipe_user_id + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) + tenant_id = data.get("tenantId", "public") + email = None if "email" not in data else data["email"] + user_context = data.get("userContext") + + response = create_email_verification_token( + tenant_id, recipe_user_id, email, user_context + ) + + if isinstance(response, CreateEmailVerificationTokenOkResult): + return jsonify({"status": "OK", "token": response.token}) + else: + return jsonify({"status": "EMAIL_ALREADY_VERIFIED_ERROR"}) + + @app.route("/test/emailverification/verifyemailusingtoken", methods=["POST"]) # type: ignore + def f2(): # type: ignore + from supertokens_python.recipe.emailverification.syncio import ( + verify_email_using_token, + ) + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + tenant_id = data.get("tenantId", "public") + token = data["token"] + attempt_account_linking = data.get("attemptAccountLinking", True) + user_context = data.get("userContext", {}) + + response = verify_email_using_token( + tenant_id, token, attempt_account_linking, user_context + ) + + if isinstance(response, VerifyEmailUsingTokenOkResult): + return jsonify( + { + "status": "OK", + "user": { + "email": response.user.email, + "recipeUserId": { + # this is intentionally done this way cause the test in the test suite expects this way. + "recipeUserId": response.user.recipe_user_id.get_as_string() + }, + }, + } + ) + else: + return jsonify({"status": "EMAIL_VERIFICATION_INVALID_TOKEN_ERROR"}) + + @app.route("/test/emailverification/unverifyemail", methods=["POST"]) # type: ignore + def unverify_email(): # type: ignore + from supertokens_python.recipe.emailverification.syncio import unverify_email + from supertokens_python.types import RecipeUserId + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id = RecipeUserId(data["recipeUserId"]) + email = data.get("email") + user_context = data.get("userContext", {}) + + unverify_email(recipe_user_id, email, user_context) + return jsonify({"status": "OK"}) + + @app.route("/test/emailverification/updatesessionifrequiredpostemailverification", methods=["POST"]) # type: ignore + def update_session_if_required_post_email_verification(): # type: ignore + from supertokens_python.recipe.emailverification import EmailVerificationRecipe + from supertokens_python.types import RecipeUserId + from session import convert_session_to_container, convert_session_to_json + + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + recipe_user_id_whose_email_got_verified = RecipeUserId( + data["recipeUserIdWhoseEmailGotVerified"]["recipeUserId"] + ) + session = convert_session_to_container(data) if "session" in data else None + + session_resp = async_to_sync_wrapper.sync( + EmailVerificationRecipe.get_instance_or_throw().update_session_if_required_post_email_verification( + recipe_user_id_whose_email_got_verified=recipe_user_id_whose_email_got_verified, + session=session, + req=FlaskRequest(request), + user_context=data.get("userContext", {}), + ) + ) + return jsonify( + None if session_resp is None else convert_session_to_json(session_resp) + ) diff --git a/tests/test-server/multifactorauth.py b/tests/test-server/multifactorauth.py new file mode 100644 index 000000000..d82048917 --- /dev/null +++ b/tests/test-server/multifactorauth.py @@ -0,0 +1,200 @@ +from typing import List +from flask import Flask, request, jsonify + +from supertokens_python import async_to_sync_wrapper +from supertokens_python.recipe.multifactorauth.types import MFAClaimValue +from supertokens_python.types import User + + +def add_multifactorauth_routes(app: Flask): + @app.route("/test/multifactorauthclaim/fetchvalue", methods=["POST"]) # type: ignore + def fetch_value_api(): # type: ignore + from supertokens_python import convert_to_recipe_user_id + from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, + ) + + assert request.json is not None + response: MFAClaimValue = async_to_sync_wrapper.sync( # type: ignore + MultiFactorAuthClaim.fetch_value( # type: ignore + request.json["_userId"], + convert_to_recipe_user_id(request.json["recipeUserId"]), + request.json["tenantId"], + request.json["currentPayload"], + request.json.get("userContext"), + ) + ) + return jsonify( + { + "c": response.c, # type: ignore + "v": response.v, # type: ignore + } + ) + + @app.route("/test/multifactorauth/getfactorssetupforuser", methods=["POST"]) # type: ignore + def get_factors_setup_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + get_factors_setup_for_user, + ) + + assert request.json is not None + user_id = request.json["userId"] + user_context = request.json.get("userContext") + + response = get_factors_setup_for_user( + user_id=user_id, + user_context=user_context, + ) + return jsonify(response) + + @app.route("/test/assertallowedtosetupfactorelsethowinvalidclaimerror", methods=["POST"]) # type: ignore + def assert_allowed_to_setup_factor_else_throw_invalid_claim_error_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + assert_allowed_to_setup_factor_else_throw_invalid_claim_error, + ) + from session import convert_session_to_container + + assert request.json is not None + + session = None + if request.json.get("session"): + session = convert_session_to_container(request.json) + assert session is not None + + assert_allowed_to_setup_factor_else_throw_invalid_claim_error( + session=session, + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/getmfarequirementsforauth", methods=["POST"]) # type: ignore + def get_mfa_requirements_for_auth_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + get_mfa_requirements_for_auth, + ) + from session import convert_session_to_container + + assert request.json is not None + + session = None + if request.json.get("session"): + session = convert_session_to_container(request.json) + assert session is not None + + response = get_mfa_requirements_for_auth( + session=session, + user_context=request.json.get("userContext"), + ) + + return jsonify(response) + + @app.route("/test/multifactorauth/markfactorascompleteinsession", methods=["POST"]) # type: ignore + def mark_factor_as_complete_in_session_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + mark_factor_as_complete_in_session, + ) + from session import convert_session_to_container + + assert request.json is not None + + session = None + if request.json.get("session"): + session = convert_session_to_container(request.json) + assert session is not None + + mark_factor_as_complete_in_session( + session=session, + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/getrequiredsecondaryfactorsforuser", methods=["POST"]) # type: ignore + def get_required_secondary_factors_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + get_required_secondary_factors_for_user, + ) + + assert request.json is not None + + response = get_required_secondary_factors_for_user( + user_id=request.json["userId"], + user_context=request.json.get("userContext"), + ) + + return jsonify(response) + + @app.route("/test/multifactorauth/addtorequiredsecondaryfactorsforuser", methods=["POST"]) # type: ignore + def add_to_required_secondary_factors_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + add_to_required_secondary_factors_for_user, + ) + + assert request.json is not None + + add_to_required_secondary_factors_for_user( + user_id=request.json["userId"], + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/removefromrequiredsecondaryfactorsforuser", methods=["POST"]) # type: ignore + def remove_from_required_secondary_factors_for_user_api(): # type: ignore + from supertokens_python.recipe.multifactorauth.syncio import ( + remove_from_required_secondary_factors_for_user, + ) + + assert request.json is not None + + remove_from_required_secondary_factors_for_user( + user_id=request.json["userId"], + factor_id=request.json["factorId"], + user_context=request.json.get("userContext"), + ) + + return "", 200 + + @app.route("/test/multifactorauth/recipeimplementation.getmfarequirementsforauth", methods=["POST"]) # type: ignore + def get_mfa_requirements_for_auth_api2(): # type: ignore + assert request.json is not None + from supertokens_python.recipe.multifactorauth.recipe import ( + MultiFactorAuthRecipe, + ) + + async def user() -> User: + assert request.json is not None + user_json = request.json["user"] + assert user_json is not None + return User.from_json(user_json) + + async def factors_set_up_for_user() -> List[str]: + assert request.json is not None + return request.json["factorsSetUpForUser"] + + async def required_secondary_factors_for_user() -> List[str]: + assert request.json is not None + return request.json["requiredSecondaryFactorsForUser"] + + async def required_secondary_factors_for_tenant() -> List[str]: + assert request.json is not None + return request.json["requiredSecondaryFactorsForTenant"] + + response = async_to_sync_wrapper.sync( + MultiFactorAuthRecipe.get_instance_or_throw_error().recipe_implementation.get_mfa_requirements_for_auth( + tenant_id=request.json["tenantId"], + access_token_payload=request.json["accessTokenPayload"], + completed_factors=request.json["completedFactors"], + user=user, + factors_set_up_for_user=factors_set_up_for_user, + required_secondary_factors_for_user=required_secondary_factors_for_user, + required_secondary_factors_for_tenant=required_secondary_factors_for_tenant, + user_context=request.json.get("userContext"), + ) + ) + + return jsonify(response) diff --git a/tests/test-server/multitenancy.py b/tests/test-server/multitenancy.py index 6af638aae..c7ee4c760 100644 --- a/tests/test-server/multitenancy.py +++ b/tests/test-server/multitenancy.py @@ -1,7 +1,7 @@ from flask import Flask, request, jsonify from supertokens_python.recipe.multitenancy.interfaces import ( AssociateUserToTenantOkResult, - TenantConfig, + TenantConfigCreateOrUpdate, ) import supertokens_python.recipe.multitenancy.syncio as multitenancy from supertokens_python.recipe.thirdparty import ( @@ -10,6 +10,7 @@ ProviderInput, ) from supertokens_python.recipe.thirdparty.provider import UserFields, UserInfoMap +from supertokens_python.types import RecipeUserId def add_multitenancy_routes(app: Flask): @@ -19,14 +20,18 @@ def create_or_update_tenant(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data["tenantId"] - config = data["config"] user_context = data.get("userContext") - config = TenantConfig( - email_password_enabled=config.get("emailPasswordEnabled"), - passwordless_enabled=config.get("passwordlessEnabled"), - third_party_enabled=config.get("thirdPartyEnabled"), - core_config=config.get("coreConfig"), + config = ( + TenantConfigCreateOrUpdate( + first_factors=data["config"].get("firstFactors"), + required_secondary_factors=data["config"].get( + "requiredSecondaryFactors" + ), + core_config=data["config"].get("coreConfig", {}), + ) + if "config" in data + else None ) response = multitenancy.create_or_update_tenant(tenant_id, config, user_context) @@ -62,9 +67,9 @@ def get_tenant(): # type: ignore { "status": "OK", "tenant": { - "emailPassword": response.emailpassword.to_json(), - "thirdParty": response.third_party.to_json(), - "passwordless": response.passwordless.to_json(), + "firstFactors": response.first_factors, + "requiredSecondaryFactors": response.required_secondary_factors, + "thirdPartyProviders": response.third_party_providers, "coreConfig": response.core_config, }, } @@ -119,32 +124,34 @@ def create_or_update_third_party_config(): # type: ignore user_info_endpoint_headers=config.get("userInfoEndpointHeaders"), jwks_uri=config.get("jwksURI"), oidc_discovery_endpoint=config.get("oidcDiscoveryEndpoint"), - user_info_map=UserInfoMap( - from_id_token_payload=UserFields( - user_id=config.get("userInfoMap", {}) - .get("fromIdTokenPayload", {}) - .get("userId"), - email=config.get("userInfoMap", {}) - .get("fromIdTokenPayload", {}) - .get("email"), - email_verified=config.get("userInfoMap", {}) - .get("fromIdTokenPayload", {}) - .get("emailVerified"), - ), - from_user_info_api=UserFields( - user_id=config.get("userInfoMap", {}) - .get("fromUserInfoAPI", {}) - .get("userId"), - email=config.get("userInfoMap", {}) - .get("fromUserInfoAPI", {}) - .get("email"), - email_verified=config.get("userInfoMap", {}) - .get("fromUserInfoAPI", {}) - .get("emailVerified"), - ), - ) - if "userInfoMap" in config - else None, + user_info_map=( + UserInfoMap( + from_id_token_payload=UserFields( + user_id=config.get("userInfoMap", {}) + .get("fromIdTokenPayload", {}) + .get("userId"), + email=config.get("userInfoMap", {}) + .get("fromIdTokenPayload", {}) + .get("email"), + email_verified=config.get("userInfoMap", {}) + .get("fromIdTokenPayload", {}) + .get("emailVerified"), + ), + from_user_info_api=UserFields( + user_id=config.get("userInfoMap", {}) + .get("fromUserInfoAPI", {}) + .get("userId"), + email=config.get("userInfoMap", {}) + .get("fromUserInfoAPI", {}) + .get("email"), + email_verified=config.get("userInfoMap", {}) + .get("fromUserInfoAPI", {}) + .get("emailVerified"), + ), + ) + if "userInfoMap" in config + else None + ), require_email=config.get("requireEmail", True), ) ) @@ -176,11 +183,11 @@ def associate_user_to_tenant(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data["tenantId"] - user_id = data["userId"] + recipe_user_id = RecipeUserId(data["recipeUserId"]) user_context = data.get("userContext") response = multitenancy.associate_user_to_tenant( - tenant_id, user_id, user_context + tenant_id, recipe_user_id, user_context ) if isinstance(response, AssociateUserToTenantOkResult): @@ -202,7 +209,7 @@ def disassociate_user_from_tenant(): # type: ignore user_id = data["userId"] user_context = data.get("userContext") - response = multitenancy.dissociate_user_from_tenant( + response = multitenancy.disassociate_user_from_tenant( tenant_id, user_id, user_context ) diff --git a/tests/test-server/override_logging.py b/tests/test-server/override_logging.py index a7dab1b76..03a84328d 100644 --- a/tests/test-server/override_logging.py +++ b/tests/test-server/override_logging.py @@ -1,9 +1,53 @@ -from typing import Any, Dict, List, Set, Union +import json +from typing import Any, Callable, Coroutine, Dict, List, Set, Union import time from httpx import Response from supertokens_python.framework.flask.flask_request import FlaskRequest +from supertokens_python.recipe.accountlinking import RecipeLevelUser +from supertokens_python.recipe.accountlinking.interfaces import ( + CreatePrimaryUserOkResult, + LinkAccountsOkResult, +) +from supertokens_python.recipe.accountlinking.types import AccountInfoWithRecipeId +from supertokens_python.recipe.emailpassword.types import ( + FormField, + PasswordResetEmailTemplateVars, +) +from supertokens_python.recipe.emailpassword.interfaces import ( + APIOptions as EmailPasswordAPIOptions, + ConsumePasswordResetTokenOkResult, + CreateResetPasswordOkResult, + GeneratePasswordResetTokenPostOkResult, + PasswordResetPostOkResult, + SignUpOkResult, + SignUpPostOkResult, + UpdateEmailOrPasswordOkResult, +) +from supertokens_python.recipe.emailverification.interfaces import ( + CreateEmailVerificationTokenEmailAlreadyVerifiedError, + CreateEmailVerificationTokenOkResult, + GetEmailForUserIdOkResult, + VerifyEmailUsingTokenOkResult, +) +from supertokens_python.recipe.emailverification.recipe import IsVerifiedSCV +from supertokens_python.recipe.session.claims import PrimitiveClaim +from supertokens_python.recipe.session.interfaces import ( + ClaimsValidationResult, + RegenerateAccessTokenOkResult, + SessionInformationResult, +) +from supertokens_python.recipe.session.session_class import Session +from supertokens_python.recipe.thirdparty.interfaces import ( + APIOptions as ThirdPartyAPIOptions, +) +from supertokens_python.recipe.passwordless.interfaces import ( + APIOptions as PasswordlessAPIOptions, +) +from supertokens_python.recipe.thirdparty.provider import ProviderConfigForClient +from supertokens_python.recipe.thirdparty.types import UserInfo as TPUserInfo +from supertokens_python.types import AccountInfo, RecipeUserId, User override_logs: List[Dict[str, Any]] = [] @@ -40,5 +84,83 @@ def transform_logged_data(data: Any, visited: Union[Set[Any], None] = None) -> A return "FlaskRequest" if isinstance(data, Response): return "Response" + if isinstance(data, RecipeUserId): + return data.get_as_string() + if isinstance(data, AccountInfoWithRecipeId): + return data.to_json() + if isinstance(data, AccountInfo): + return data.to_json() + if isinstance(data, User): + return data.to_json() + if isinstance(data, Coroutine): + return "Coroutine" + if isinstance(data, Callable): + return "Callable" + if isinstance(data, FormField): + return data.to_json() + if isinstance(data, EmailPasswordAPIOptions): + return "EmailPasswordAPIOptions" + if isinstance(data, ThirdPartyAPIOptions): + return "ThirdPartyAPIOptions" + if isinstance(data, PasswordlessAPIOptions): + return "PasswordlessAPIOptions" + if isinstance(data, SignUpOkResult): + return data.to_json() + if isinstance(data, CreatePrimaryUserOkResult): + return data.to_json() + if isinstance(data, LinkAccountsOkResult): + return data.to_json() + if isinstance(data, Session): + from session import convert_session_to_json - return data + return convert_session_to_json(data) + if isinstance(data, SignUpPostOkResult): + return data.to_json() + if isinstance(data, ClaimsValidationResult): + return data.to_json() + if isinstance(data, ProviderConfigForClient): + return data.to_json() + if isinstance(data, TPUserInfo): + return data.to_json() + if isinstance(data, GeneratePasswordResetTokenPostOkResult): + return data.to_json() + if isinstance(data, CreateEmailVerificationTokenOkResult): + return {"token": data.token, "status": data.status} + if isinstance(data, GetEmailForUserIdOkResult): + return {"email": data.email, "status": "OK"} + if isinstance(data, VerifyEmailUsingTokenOkResult): + return {"status": data.status} + if isinstance(data, CreateResetPasswordOkResult): + return {"token": data.token, "status": "OK"} + if isinstance(data, PasswordResetEmailTemplateVars): + return data.to_json() + if isinstance(data, ConsumePasswordResetTokenOkResult): + return data.to_json() + if isinstance(data, UpdateEmailOrPasswordOkResult): + return {"status": "OK"} + if isinstance(data, CreateEmailVerificationTokenEmailAlreadyVerifiedError): + return {"status": "EMAIL_ALREADY_VERIFIED_ERROR"} + if isinstance(data, PasswordResetPostOkResult): + return {"status": "OK", "user": data.user.to_json(), "email": data.email} + if isinstance(data, RecipeLevelUser): + return data.to_json() + if isinstance(data, RegenerateAccessTokenOkResult): + return data.to_json() + if isinstance(data, PrimitiveClaim): + return "PrimitiveClaim" + if isinstance(data, SessionInformationResult): + return data.to_json() + if isinstance(data, IsVerifiedSCV): + return "IsVerifiedSCV" + if is_jsonable(data): + return data + + return "Some custom object" + + +def is_jsonable(x: Any) -> bool: + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False diff --git a/tests/test-server/passwordless.py b/tests/test-server/passwordless.py new file mode 100644 index 000000000..eb88607d2 --- /dev/null +++ b/tests/test-server/passwordless.py @@ -0,0 +1,171 @@ +from flask import Flask, request, jsonify +from supertokens_python import convert_to_recipe_user_id +from supertokens_python.recipe.passwordless.interfaces import ( + ConsumeCodeExpiredUserInputCodeError, + ConsumeCodeIncorrectUserInputCodeError, + ConsumeCodeOkResult, + ConsumeCodeRestartFlowError, + EmailChangeNotAllowedError, + UpdateUserEmailAlreadyExistsError, + UpdateUserOkResult, + UpdateUserPhoneNumberAlreadyExistsError, + UpdateUserUnknownUserIdError, +) +from supertokens_python.recipe.passwordless.syncio import ( + signinup, + create_code, + update_user, + consume_code, +) +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error +from session import convert_session_to_container # pylint: disable=import-error + + +def add_passwordless_routes(app: Flask): + @app.route("/test/passwordless/signinup", methods=["POST"]) # type: ignore + def sign_in_up_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = signinup( + email=body.get("email", None), + phone_number=body.get("phoneNumber", None), + tenant_id=body.get("tenantId", "public"), + user_context=body.get("userContext"), + session=session, + ) + return jsonify( + { + "status": "OK", + "createdNewRecipeUser": response.created_new_recipe_user, + "consumedDevice": response.consumed_device.to_json(), + **serialize_user(response.user, request.headers.get("fdi-version", "")), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + + @app.route("/test/passwordless/createcode", methods=["POST"]) # type: ignore + def create_code_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = create_code( + email=body.get("email"), + phone_number=body.get("phoneNumber"), + tenant_id=body.get("tenantId", "public"), + user_input_code=body.get("userInputCode"), + user_context=body.get("userContext"), + session=session, + ) + return jsonify( + { + "status": "OK", + "codeId": response.code_id, + "preAuthSessionId": response.pre_auth_session_id, + "codeLifeTime": response.code_life_time, + "deviceId": response.device_id, + "linkCode": response.link_code, + "timeCreated": response.time_created, + "userInputCode": response.user_input_code, + } + ) + + @app.route("/test/passwordless/consumecode", methods=["POST"]) # type: ignore + def consume_code_api(): # type: ignore + assert request.json is not None + body = request.json + session = None + if "session" in body: + session = convert_session_to_container(body) + + response = consume_code( + device_id=body.get("deviceId"), + pre_auth_session_id=body.get("preAuthSessionId"), + user_input_code=body.get("userInputCode"), + link_code=body.get("linkCode", None), + tenant_id=body.get("tenantId", "public"), + user_context=body.get("userContext"), + session=session, + ) + + if isinstance(response, ConsumeCodeOkResult): + return jsonify( + { + "status": "OK", + "createdNewRecipeUser": response.created_new_recipe_user, + "consumedDevice": response.consumed_device.to_json(), + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, ConsumeCodeIncorrectUserInputCodeError): + return jsonify( + { + "status": "INCORRECT_USER_INPUT_CODE_ERROR", + "failedCodeInputAttemptCount": response.failed_code_input_attempt_count, + "maximumCodeInputAttempts": response.maximum_code_input_attempts, + } + ) + elif isinstance(response, ConsumeCodeExpiredUserInputCodeError): + return jsonify( + { + "status": "EXPIRED_USER_INPUT_CODE_ERROR", + "failedCodeInputAttemptCount": response.failed_code_input_attempt_count, + "maximumCodeInputAttempts": response.maximum_code_input_attempts, + } + ) + elif isinstance(response, ConsumeCodeRestartFlowError): + return jsonify({"status": "RESTART_FLOW_ERROR"}) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) + + @app.route("/test/passwordless/updateuser", methods=["POST"]) # type: ignore + def update_user_api(): # type: ignore + assert request.json is not None + body = request.json + response = update_user( + recipe_user_id=convert_to_recipe_user_id(body["recipeUserId"]), + email=body.get("email"), + phone_number=body.get("phoneNumber"), + user_context=body.get("userContext"), + ) + + if isinstance(response, UpdateUserOkResult): + return jsonify({"status": "OK"}) + elif isinstance(response, UpdateUserUnknownUserIdError): + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + elif isinstance(response, UpdateUserEmailAlreadyExistsError): + return jsonify({"status": "EMAIL_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, UpdateUserPhoneNumberAlreadyExistsError): + return jsonify({"status": "PHONE_NUMBER_ALREADY_EXISTS_ERROR"}) + elif isinstance(response, EmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": response.reason} + ) + else: + return jsonify( + { + "status": "PHONE_NUMBER_CHANGE_NOT_ALLOWED_ERROR", + "reason": response.reason, + } + ) diff --git a/tests/test-server/session.py b/tests/test-server/session.py index df2f1cc7f..cf1dc174e 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,9 +1,24 @@ +from typing import Any, Dict from flask import Flask, request, jsonify -from utils import deserialize_validator -from supertokens_python import async_to_sync_wrapper +from override_logging import log_override_event # pylint: disable=import-error +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.exceptions import TokenTheftError +from supertokens_python.recipe.session.interfaces import ( + SessionDoesNotExistError, + TokenInfo, +) +from supertokens_python.recipe.session.jwt import ( + parse_jwt_without_signature_verification, +) +from supertokens_python.types import RecipeUserId +from utils import ( # pylint: disable=import-error + deserialize_validator, + get_max_version, +) from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.recipe.session.session_class import Session import supertokens_python.recipe.session.syncio as session +from utils import deserialize_claim # pylint: disable=import-error def add_session_routes(app: Flask): @@ -14,7 +29,18 @@ def create_new_session_without_request_response(): # type: ignore return jsonify({"status": "MISSING_DATA_ERROR"}) tenant_id = data.get("tenantId", "public") - user_id = data["userId"] + from supertokens_python import convert_to_recipe_user_id + + fdi_version = request.headers.get("fdi-version") + assert fdi_version is not None + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + # fdi_version <= "1.17" or (fdi_version >= "2.0" and fdi_version < "3.0") + recipe_user_id = convert_to_recipe_user_id(data["userId"]) + else: + recipe_user_id = convert_to_recipe_user_id(data["recipeUserId"]) access_token_payload = data.get("accessTokenPayload", {}) session_data_in_database = data.get("sessionDataInDatabase", {}) disable_anti_csrf = data.get("disableAntiCsrf") @@ -22,34 +48,85 @@ def create_new_session_without_request_response(): # type: ignore session_container = session.create_new_session_without_request_response( tenant_id, - user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, user_context, ) - return jsonify( - { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - } + return jsonify(convert_session_to_json(session_container)) + + @app.route("/test/session/getallsessionhandlesforuser", methods=["POST"]) # type: ignore + def get_all_session_handles_for_user_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + user_id = data["userId"] + fetch_sessions_for_all_linked_accounts = data.get( + "fetchSessionsForAllLinkedAccounts", True + ) + tenant_id = data.get("tenantId", "public") + user_context = data.get("userContext", {}) + + response = session.get_all_session_handles_for_user( + user_id, fetch_sessions_for_all_linked_accounts, tenant_id, user_context ) + return jsonify(response) + + @app.route("/test/session/revokeallsessionsforuser", methods=["POST"]) # type: ignore + def revoke_all_sessions_for_user_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + user_id = data["userId"] + revoke_sessions_for_linked_accounts = data.get( + "revokeSessionsForLinkedAccounts", True + ) + tenant_id = data.get("tenantId", None) + user_context = data.get("userContext", {}) + + response = session.revoke_all_sessions_for_user( + user_id, revoke_sessions_for_linked_accounts, tenant_id, user_context + ) + return jsonify(response) + + @app.route("/test/session/refreshsessionwithoutrequestresponse", methods=["POST"]) # type: ignore + def refresh_session_without_request_response(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + refresh_token = data["refreshToken"] + disable_anti_csrf = data.get("disableAntiCsrf") + anti_csrf_token = data.get("antiCsrfToken") + user_context = data.get("userContext", {}) + + try: + response = session.refresh_session_without_request_response( + refresh_token, disable_anti_csrf, anti_csrf_token, user_context + ) + return jsonify(convert_session_to_json(response)) + except Exception as e: + if isinstance(e, TokenTheftError): + return ( + jsonify( + { + "type": "TOKEN_THEFT_DETECTED", + "payload": { + "recipeUserId": { + # this is done this way cause the frontend test suite expects the json in this format + "recipeUserId": e.recipe_user_id.get_as_string() + }, + "userId": e.user_id, + }, + } + ), + 500, + ) + return jsonify({"message": str(e)}), 500 @app.route("/test/session/getsessionwithoutrequestresponse", methods=["POST"]) # type: ignore def get_session_without_request_response(): # type: ignore @@ -62,13 +139,14 @@ def get_session_without_request_response(): # type: ignore options = data.get("options") user_context = data.get("userContext", {}) - try: - session_container = session.get_session_without_request_response( - access_token, anti_csrf_token, options, user_context - ) - return jsonify(session_container) - except Exception as e: - return jsonify({"error": str(e)}), 500 + session_container = session.get_session_without_request_response( + access_token, anti_csrf_token, options, user_context + ) + return jsonify( + None + if session_container is None + else convert_session_to_json(session_container) + ) @app.route("/test/session/sessionobject/assertclaims", methods=["POST"]) # type: ignore def assert_claims(): # type: ignore @@ -76,50 +154,17 @@ def assert_claims(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - session_container = Session( - recipe_implementation=SessionRecipe.get_instance().recipe_implementation, - config=SessionRecipe.get_instance().config, - access_token=data["session"]["accessToken"], - front_token=data["session"]["frontToken"], - refresh_token=None, # We don't have refresh token in the input - anti_csrf_token=None, # We don't have anti-csrf token in the input - session_handle=data["session"]["sessionHandle"], - user_id=data["session"]["userId"], - user_data_in_access_token=data["session"]["userDataInAccessToken"], - req_res_info=None, # We don't have this information in the input - access_token_updated=data["session"]["accessTokenUpdated"], - tenant_id=data["session"]["tenantId"], - ) + session_container = convert_session_to_container(data) claim_validators = list(map(deserialize_validator, data["claimValidators"])) user_context = data.get("userContext", {}) try: - async_to_sync_wrapper.sync( - session_container.assert_claims(claim_validators, user_context) - ) + session_container.sync_assert_claims(claim_validators, user_context) return jsonify( { "status": "OK", - "updatedSession": { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - }, + "updatedSession": convert_session_to_json(session_container), } ) except Exception as e: @@ -134,50 +179,221 @@ def merge_into_access_token_payload_on_session_object(): # type: ignore if data is None: return jsonify({"status": "MISSING_DATA_ERROR"}) - session_container = Session( - recipe_implementation=SessionRecipe.get_instance().recipe_implementation, - config=SessionRecipe.get_instance().config, - access_token=data["session"]["accessToken"], - front_token=data["session"]["frontToken"], - refresh_token=None, # We don't have refresh token in the input - anti_csrf_token=None, # We don't have anti-csrf token in the input - session_handle=data["session"]["sessionHandle"], - user_id=data["session"]["userId"], - user_data_in_access_token=data["session"]["userDataInAccessToken"], - req_res_info=None, # We don't have this information in the input - access_token_updated=data["session"]["accessTokenUpdated"], - tenant_id=data["session"]["tenantId"], - ) + session_container = convert_session_to_container(data) + access_token_payload_update = data["accessTokenPayloadUpdate"] user_context = data.get("userContext", {}) - async_to_sync_wrapper.sync( - session_container.merge_into_access_token_payload( - access_token_payload_update, user_context - ) + session_container.sync_merge_into_access_token_payload( + access_token_payload_update, user_context ) return jsonify( { "status": "OK", - "updatedSession": { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - }, + "updatedSession": convert_session_to_json(session_container), } ) + + @app.route("/test/session/mergeintoaccesspayload", methods=["POST"]) # type: ignore + def merge_into_access_payload(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + session_handle = data["sessionHandle"] + access_token_payload_update = data["accessTokenPayloadUpdate"] + user_context = data.get("userContext", {}) + + try: + response = session.merge_into_access_token_payload( + session_handle, access_token_payload_update, user_context + ) + return jsonify(response) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 + + @app.route("/test/session/validateclaimsforsessionhandle", methods=["POST"]) # type: ignore + def validate_claims_for_session_handle(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + session_handle = data["sessionHandle"] + override_global_claim_validators = None + if "overrideGlobalClaimValidators" in data: + from test_functions_mapper import get_func + + override_global_claim_validators = get_func( + data["overrideGlobalClaimValidators"] + ) + user_context = data.get("userContext", {}) + + try: + response = session.validate_claims_for_session_handle( + session_handle, override_global_claim_validators, user_context + ) + if isinstance(response, SessionDoesNotExistError): + return jsonify({"status": "SESSION_DOES_NOT_EXIST"}) + return jsonify(response.to_json()) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 + + @app.route("/test/session/getsessioninformation", methods=["POST"]) # type: ignore + def get_session_information_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + session_handle = data["sessionHandle"] + user_context = data.get("userContext", {}) + + response = session.get_session_information(session_handle, user_context) + if response is None: + return jsonify(None) + return jsonify( + { + "customClaimsInAccessTokenPayload": response.custom_claims_in_access_token_payload, + "sessionDataInDatabase": response.session_data_in_database, + "expiry": response.expiry, + "sessionHandle": response.session_handle, + "recipeUserId": response.recipe_user_id.get_as_string(), + "tenantId": response.tenant_id, + "timeCreated": response.time_created, + "userId": response.user_id, + } + ) + + @app.route("/test/session/sessionobject/fetchandsetclaim", methods=["POST"]) # type: ignore + def session_object_fetch_and_set_claim_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.fetchandsetclaim", "CALL", data) + session = convert_session_to_container(data) + + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + session.sync_fetch_and_set_claim(claim, user_context) + response = {"updatedSession": convert_session_to_json(session)} + return jsonify(response) + + @app.route("/test/session/fetchandsetclaim", methods=["POST"]) # type: ignore + def fetch_and_set_claim_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("session.fetchandsetclaim", "CALL", data) + session_handle = data["sessionHandle"] + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + try: + response = session.fetch_and_set_claim(session_handle, claim, user_context) + return jsonify(response) + except Exception as e: + return jsonify({"status": "ERROR", "message": str(e)}), 500 + + @app.route("/test/session/sessionobject/getclaimvalue", methods=["POST"]) # type: ignore + def get_claim_value_api(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + log_override_event("sessionobject.getclaimvalue", "CALL", data) + session = convert_session_to_container(data) + + claim = deserialize_claim(data["claim"]) + user_context = data.get("userContext", {}) + + try: + ret_val = session.sync_get_claim_value(claim, user_context) + response = { + "retVal": ret_val, + "updatedSession": convert_session_to_json(session), + } + log_override_event("sessionobject.getclaimvalue", "RES", ret_val) + return jsonify(response) + except Exception as e: + log_override_event("sessionobject.getclaimvalue", "REJ", str(e)) + raise e + + +def convert_session_to_json(session_container: SessionContainer) -> Dict[str, Any]: + return { + "sessionHandle": session_container.get_handle(), + "userId": session_container.get_user_id(), + "tenantId": session_container.get_tenant_id(), + "userDataInAccessToken": session_container.get_access_token_payload(), + "accessToken": session_container.get_access_token(), + "frontToken": session_container.get_all_session_tokens_dangerously()[ + "frontToken" + ], + "refreshToken": ( + session_container.refresh_token.to_json() + if session_container.refresh_token is not None + else None + ), + "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ + "antiCsrfToken" + ], + "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ + "accessAndFrontTokenUpdated" + ], + "recipeUserId": { + # this is intentionally done this way cause the test in the test suite expects this way. + "recipeUserId": session_container.get_recipe_user_id().get_as_string() + }, + } + + +def convert_session_to_container(data: Any) -> Session: + jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) + jwt_payload = jwt_info.payload + + user_id = jwt_payload["userId"] if jwt_info.version == 2 else jwt_payload["sub"] + session_handle = jwt_payload["sessionHandle"] + + recipe_user_id = RecipeUserId(jwt_payload.get("rsub", user_id)) + anti_csrf_token = jwt_payload.get("antiCsrfToken") + tenant_id = jwt_payload["tId"] if jwt_info.version >= 4 else "public" + + return Session( + recipe_implementation=SessionRecipe.get_instance().recipe_implementation, + config=SessionRecipe.get_instance().config, + access_token=data["session"]["accessToken"], + front_token=data["session"]["frontToken"], + refresh_token=( + TokenInfo( + ( + data["session"]["refreshToken"] + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["token"] + ), + ( + -1 + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["expiry"] + ), + ( + -1 + if isinstance(data["session"]["refreshToken"], str) + else data["session"]["refreshToken"]["createdTime"] + ), + ) + if "refreshToken" in data["session"] + and data["session"]["refreshToken"] is not None + else None + ), + anti_csrf_token=anti_csrf_token, + session_handle=session_handle, + user_id=user_id, + recipe_user_id=recipe_user_id, + user_data_in_access_token=jwt_payload, + req_res_info=None, # We don't have this information in the input + access_token_updated=False, + tenant_id=tenant_id, + ) diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py new file mode 100644 index 000000000..a570e149a --- /dev/null +++ b/tests/test-server/supertokens.py @@ -0,0 +1,87 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import AccountInfo +from supertokens_python.syncio import ( + get_user, + delete_user, + list_users_by_account_info, + get_users_newest_first, + get_users_oldest_first, +) + + +def add_supertokens_routes(app: Flask): + @app.route("/test/supertokens/getuser", methods=["POST"]) # type: ignore + def get_user_api(): # type: ignore + assert request.json is not None + response = get_user(request.json["userId"], request.json.get("userContext")) + return jsonify(None if response is None else response.to_json()) + + @app.route("/test/supertokens/deleteuser", methods=["POST"]) # type: ignore + def delete_user_api(): # type: ignore + assert request.json is not None + delete_user( + request.json["userId"], + request.json.get("removeAllLinkedAccounts", True), + request.json.get("userContext"), + ) + return jsonify({"status": "OK"}) + + @app.route("/test/supertokens/listusersbyaccountinfo", methods=["POST"]) # type: ignore + def list_users_by_account_info_api(): # type: ignore + assert request.json is not None + response = list_users_by_account_info( + request.json["tenantId"], + AccountInfo( + email=request.json["accountInfo"].get("email", None), + phone_number=request.json["accountInfo"].get("phoneNumber", None), + third_party=( + None + if "thirdParty" not in request.json["accountInfo"] + else ThirdPartyInfo( + third_party_id=request.json["accountInfo"]["thirdParty"]["id"], + third_party_user_id=request.json["accountInfo"]["thirdParty"][ + "userId" + ], + ) + ), + ), + request.json.get("doUnionOfAccountInfo", False), + request.json.get("userContext"), + ) + + return jsonify([r.to_json() for r in response]) + + @app.route("/test/supertokens/getusersnewestfirst", methods=["POST"]) # type: ignore + def get_users_newest_first_api(): # type: ignore + assert request.json is not None + response = get_users_newest_first( + include_recipe_ids=request.json.get("includeRecipeIds"), + limit=request.json.get("limit"), + pagination_token=request.json.get("paginationToken"), + tenant_id=request.json.get("tenantId"), + user_context=request.json.get("userContext"), + ) + return jsonify( + { + "nextPaginationToken": response.next_pagination_token, + "users": [r.to_json() for r in response.users], + } + ) + + @app.route("/test/supertokens/getusersoldestfirst", methods=["POST"]) # type: ignore + def get_users_oldest_first_api(): # type: ignore + assert request.json is not None + response = get_users_oldest_first( + include_recipe_ids=request.json.get("includeRecipeIds"), + limit=request.json.get("limit"), + pagination_token=request.json.get("paginationToken"), + tenant_id=request.json.get("tenantId"), + user_context=request.json.get("userContext"), + ) + return jsonify( + { + "nextPaginationToken": response.next_pagination_token, + "users": [r.to_json() for r in response.users], + } + ) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 97a417d58..af5651978 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -1,11 +1,63 @@ -from typing import Callable +from typing import Callable, List, Union +from typing import Dict, Any, Optional +from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.auth_utils import LinkingToSessionUserFailedError +from supertokens_python.recipe.accountlinking import ( + RecipeLevelUser, + ShouldAutomaticallyLink, + ShouldNotAutomaticallyLink, +) +from supertokens_python.recipe.dashboard.interfaces import APIOptions +from supertokens_python.recipe.emailpassword.interfaces import ( + EmailAlreadyExistsError, + PasswordPolicyViolationError, + PasswordResetPostOkResult, + PasswordResetTokenInvalidError, + SignUpPostNotAllowedResponse, + SignUpPostOkResult, +) +from supertokens_python.recipe.emailpassword.types import ( + EmailDeliveryOverrideInput, + EmailTemplateVars, + FormField, +) +from supertokens_python.recipe.emailverification.interfaces import ( + EmailDoesNotExistError, + GetEmailForUserIdOkResult, +) +from supertokens_python.recipe.emailverification.types import ( + VerificationEmailTemplateVarsUser, +) +from supertokens_python.recipe.multifactorauth.interfaces import ( + ResyncSessionAndFetchMFAInfoPUTOkResult, +) +from supertokens_python.recipe.multifactorauth.types import MFARequirementList +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.claims import PrimitiveClaim +from supertokens_python.recipe.thirdparty.interfaces import ( + SignInUpNotAllowed, + SignInUpOkResult, + SignInUpPostNoEmailGivenByProviderResponse, + SignInUpPostOkResult, +) +from supertokens_python.recipe.thirdparty.provider import Provider, RedirectUriInfo +from supertokens_python.recipe.thirdparty.types import ( + RawUserInfoFromProvider, + UserInfo, + UserInfoEmail, +) +from supertokens_python.types import AccountInfo, GeneralErrorResponse, RecipeUserId +from supertokens_python.types import APIResponse, User class Info: core_call_count = 0 -def get_func(eval_str: str) -> Callable: # type: ignore +def get_func(eval_str: str) -> Callable[..., Any]: + global store # pylint: disable=global-variable-not-assigned + global send_email_inputs # pylint: disable=global-variable-not-assigned + global send_sms_inputs # pylint: disable=global-variable-not-assigned if eval_str.startswith("supertokens.init.supertokens.networkInterceptor"): def func(*args): # type: ignore @@ -14,4 +66,955 @@ def func(*args): # type: ignore return func # type: ignore - raise Exception("Unknown eval string") + elif eval_str.startswith("accountlinking.init.onAccountLinked"): + + async def on_account_linked( + user: User, recipe_level_user: RecipeLevelUser, user_context: Dict[str, Any] + ) -> None: + global primary_user_in_callback + global new_account_info_in_callback + primary_user_in_callback = user + new_account_info_in_callback = recipe_level_user + + return on_account_linked + + elif eval_str.startswith("multifactorauth.init.override.apis"): + from supertokens_python.recipe.multifactorauth.interfaces import ( + APIInterface as MFAAPIInterface, + APIOptions as MFAAPIOptions, + ) + + def mfa_override_apis( + original_implementation: MFAAPIInterface, + ) -> MFAAPIInterface: + original_resync_session_and_fetch_mfa_info_put = ( + original_implementation.resync_session_and_fetch_mfa_info_put + ) + + async def resync_session_and_fetch_mfa_info_put( + api_options: MFAAPIOptions, + session: SessionContainer, + user_context: Dict[str, Any], + ) -> Union[ResyncSessionAndFetchMFAInfoPUTOkResult, GeneralErrorResponse]: + json_body = await api_options.request.json() + if ( + json_body is not None + and json_body.get("userContext", {}).get("requireFactor") + is not None + ): + user_context["requireFactor"] = json_body["userContext"][ + "requireFactor" + ] + + return await original_resync_session_and_fetch_mfa_info_put( + api_options, session, user_context + ) + + original_implementation.resync_session_and_fetch_mfa_info_put = ( + resync_session_and_fetch_mfa_info_put + ) + return original_implementation + + return mfa_override_apis + + elif eval_str.startswith("multifactorauth.init.override.functions"): + from supertokens_python.recipe.multifactorauth.interfaces import ( + RecipeInterface as MFARecipeInterface, + ) + + def mfa_override_functions( + original_implementation: MFARecipeInterface, + ) -> MFARecipeInterface: + async def get_mfa_requirements_for_auth( + tenant_id: str, + access_token_payload: Dict[str, Any], + completed_factors: Dict[str, int], + user: Any, + factors_set_up_for_user: Any, + required_secondary_factors_for_user: Any, + required_secondary_factors_for_tenant: Any, + user_context: Dict[str, Any], + ) -> MFARequirementList: + return ["otp-phone"] if user_context.get("requireFactor") else [] + + original_implementation.get_mfa_requirements_for_auth = ( + get_mfa_requirements_for_auth + ) + return original_implementation + + return mfa_override_functions + + elif eval_str.startswith("emailverification.init.emailDelivery.override"): + from supertokens_python.recipe.emailverification.types import ( + EmailDeliveryOverrideInput as EVEmailDeliveryOverrideInput, + EmailTemplateVars as EVEmailTemplateVars, + ) + + def custom_email_delivery_override( + original_implementation: EVEmailDeliveryOverrideInput, + ) -> EVEmailDeliveryOverrideInput: + original_send_email = original_implementation.send_email + + async def send_email( + template_vars: EVEmailTemplateVars, user_context: Dict[str, Any] + ) -> None: + global user_in_callback # pylint: disable=global-variable-not-assigned + global token # pylint: disable=global-variable-not-assigned + + if template_vars.user: + user_in_callback = template_vars.user + + if template_vars.email_verify_link: + token = template_vars.email_verify_link.split("?token=")[1].split( + "&tenantId=" + )[0] + + # Call the original implementation + await original_send_email(template_vars, user_context) + + original_implementation.send_email = send_email + return original_implementation + + return custom_email_delivery_override + + elif eval_str.startswith("session.override.functions"): + from supertokens_python.recipe.session.interfaces import ( + RecipeInterface as SessionRecipeInterface, + ) + + def session_override_functions( + original_implementation: SessionRecipeInterface, + ) -> SessionRecipeInterface: + original_create_new_session = original_implementation.create_new_session + + async def create_new_session( + user_id: str, + recipe_user_id: RecipeUserId, + access_token_payload: Optional[Dict[str, Any]], + session_data_in_database: Optional[Dict[str, Any]], + disable_anti_csrf: Optional[bool], + tenant_id: str, + user_context: Dict[str, Any], + ) -> SessionContainer: + async def fetch_value( + _user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> None: + global user_id_in_callback + global recipe_user_id_in_callback + user_id_in_callback = user_id + recipe_user_id_in_callback = recipe_user_id + return None + + claim = PrimitiveClaim[Any](key="some-key", fetch_value=fetch_value) + + if access_token_payload is None: + access_token_payload = {} + json_update = await claim.build( + user_id, + recipe_user_id, + tenant_id, + access_token_payload, + user_context, + ) + access_token_payload.update(json_update) + + return await original_create_new_session( + user_id, + recipe_user_id, + access_token_payload, + session_data_in_database, + disable_anti_csrf, + tenant_id, + user_context, + ) + + original_implementation.create_new_session = create_new_session + return original_implementation + + return session_override_functions + + elif eval_str.startswith("emailpassword.init.emailDelivery.override"): + + def custom_email_deliver( + original_implementation: EmailDeliveryOverrideInput, + ) -> EmailDeliveryOverrideInput: + original_send_email = original_implementation.send_email + + async def send_email( + template_vars: EmailTemplateVars, user_context: Dict[str, Any] + ) -> None: + global send_email_callback_called # pylint: disable=global-variable-not-assigned + global send_email_to_user_id # pylint: disable=global-variable-not-assigned + global send_email_to_user_email # pylint: disable=global-variable-not-assigned + global send_email_to_recipe_user_id # pylint: disable=global-variable-not-assigned + global token # pylint: disable=global-variable-not-assigned + send_email_callback_called = True + + if template_vars.user: + send_email_to_user_id = template_vars.user.id + + if template_vars.user.email: + send_email_to_user_email = template_vars.user.email + + if template_vars.user.recipe_user_id: + send_email_to_recipe_user_id = ( + template_vars.user.recipe_user_id.get_as_string() + ) + + if template_vars.password_reset_link: + token = ( + template_vars.password_reset_link.split("?")[1] + .split("&")[0] + .split("=")[1] + ) + + # Use the original implementation which calls the default service, + # or a service that you may have specified in the email_delivery object. + return await original_send_email(template_vars, user_context) + + original_implementation.send_email = send_email + return original_implementation + + return custom_email_deliver + elif eval_str.startswith("passwordless.init.emailDelivery.service.sendEmail"): + + def func1( + template_vars: Any, + user_context: Dict[str, Any], # pylint: disable=unused-argument + ) -> None: # pylint: disable=unused-argument + # Add to store + jsonified = { + "codeLifeTime": template_vars.code_life_time, + "email": template_vars.email, + "isFirstFactor": template_vars.is_first_factor, + "preAuthSessionId": template_vars.pre_auth_session_id, + "tenantId": template_vars.tenant_id, + "urlWithLinkCode": template_vars.url_with_link_code, + "userInputCode": template_vars.user_input_code, + } + jsonified = {k: v for k, v in jsonified.items() if v is not None} + if "emailInputs" in store: + store["emailInputs"].append(jsonified) + else: + store["emailInputs"] = [jsonified] + + # Add to send_email_inputs + send_email_inputs.append(jsonified) + + return func1 + + if eval_str.startswith("thirdparty.init.override.functions"): + if "setIsVerifiedInSignInUp" in eval_str: + from supertokens_python.recipe.thirdparty.interfaces import ( + RecipeInterface as ThirdPartyRecipeInterface, + ) + + def custom_override( + original_implementation: ThirdPartyRecipeInterface, + ) -> ThirdPartyRecipeInterface: + og_sign_in_up = original_implementation.sign_in_up + + async def sign_in_up( + third_party_id: str, + third_party_user_id: str, + email: str, + is_verified: bool, + oauth_tokens: Dict[str, Any], + raw_user_info_from_provider: RawUserInfoFromProvider, + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + user_context: Dict[str, Any], + ) -> Union[ + SignInUpOkResult, + SignInUpNotAllowed, + LinkingToSessionUserFailedError, + ]: + user_context[ + "isVerified" + ] = is_verified # this information comes from the third party provider + return await og_sign_in_up( + third_party_id, + third_party_user_id, + email, + is_verified, + oauth_tokens, + raw_user_info_from_provider, + session, + should_try_linking_with_session_user, + tenant_id, + user_context, + ) + + original_implementation.sign_in_up = sign_in_up + return original_implementation + + return custom_override + + elif eval_str.startswith("passwordless.init.smsDelivery.service.sendSms"): + + def func2( + template_vars: Any, user_context: Dict[str, Any] + ) -> None: # pylint: disable=unused-argument + jsonified = { + "codeLifeTime": template_vars.code_life_time, + "phoneNumber": template_vars.phone_number, + "isFirstFactor": template_vars.is_first_factor, + "preAuthSessionId": template_vars.pre_auth_session_id, + "tenantId": template_vars.tenant_id, + "urlWithLinkCode": template_vars.url_with_link_code, + "userInputCode": template_vars.user_input_code, + } + jsonified = {k: v for k, v in jsonified.items() if v is not None} + send_sms_inputs.append(jsonified) + + return func2 + + elif eval_str.startswith("passwordless.init.override.apis"): + + def func3(oI: Any) -> Dict[str, Any]: + og = oI.consume_code_post + + async def consume_code_post( + pre_auth_session_id: str, + user_input_code: Union[str, None], + device_id: Union[str, None], + link_code: Union[str, None], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + api_options: APIOptions, + user_context: Dict[str, Any], + ) -> Any: + o = await api_options.request.json() + assert o is not None + if o.get("userContext", {}).get("DO_LINK") is not None: + user_context["DO_LINK"] = o["userContext"]["DO_LINK"] + return await og( + pre_auth_session_id, + user_input_code, + device_id, + link_code, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, + ) + + oI.consume_code_post = consume_code_post + return oI + + return func3 + + elif eval_str.startswith("emailpassword.init.override.apis"): + from supertokens_python.recipe.emailpassword.interfaces import ( + APIInterface as EmailPasswordAPIInterface, + APIOptions as EmailPasswordAPIOptions, + ) + + def ep_override_apis( + original_implementation: EmailPasswordAPIInterface, + ) -> EmailPasswordAPIInterface: + + og_password_reset_post = original_implementation.password_reset_post + og_sign_up_post = original_implementation.sign_up_post + + async def password_reset_post( + form_fields: List[FormField], + token: str, + tenant_id: str, + api_options: EmailPasswordAPIOptions, + user_context: Dict[str, Any], + ) -> Union[ + PasswordResetPostOkResult, + PasswordResetTokenInvalidError, + PasswordPolicyViolationError, + GeneralErrorResponse, + ]: + if "DO_NOT_LINK" in eval_str: + user_context["DO_NOT_LINK"] = True + t = await og_password_reset_post( + form_fields, token, tenant_id, api_options, user_context + ) + if isinstance(t, PasswordResetPostOkResult): + global email_post_password_reset, user_post_password_reset + email_post_password_reset = t.email + user_post_password_reset = t.user + return t + + async def sign_up_post( + form_fields: List[FormField], + tenant_id: str, + session: Union[SessionContainer, None], + should_try_linking_with_session_user: Union[bool, None], + api_options: EmailPasswordAPIOptions, + user_context: Dict[str, Any], + ) -> Union[ + SignUpPostOkResult, + EmailAlreadyExistsError, + SignUpPostNotAllowedResponse, + GeneralErrorResponse, + ]: + if "signUpPOST" in eval_str: + n = await api_options.request.json() + assert n is not None + if n.get("userContext", {}).get("DO_LINK") is not None: + user_context["DO_LINK"] = n["userContext"]["DO_LINK"] + return await og_sign_up_post( + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, + ) + + original_implementation.password_reset_post = password_reset_post + original_implementation.sign_up_post = sign_up_post + return original_implementation + + return ep_override_apis + + elif eval_str.startswith("emailverification.init.override.functions"): + from supertokens_python.recipe.emailverification.interfaces import ( + RecipeInterface as EmailVerificationRecipeInterface, + ) + + def ev_override_functions( + original_implementation: EmailVerificationRecipeInterface, + ) -> EmailVerificationRecipeInterface: + og_is_email_verified = original_implementation.is_email_verified + + async def is_email_verified( + recipe_user_id: RecipeUserId, email: str, user_context: Dict[str, Any] + ) -> bool: + global email_param + email_param = email + return await og_is_email_verified(recipe_user_id, email, user_context) + + original_implementation.is_email_verified = is_email_verified + return original_implementation + + return ev_override_functions + + elif eval_str.startswith("thirdparty.init.override.apis"): + from supertokens_python.recipe.thirdparty.interfaces import ( + APIInterface as ThirdPartyAPIInterface, + APIOptions as ThirdPartyAPIOptions, + ) + + def tp_override_apis( + original_implementation: ThirdPartyAPIInterface, + ) -> ThirdPartyAPIInterface: + + og_sign_in_up_post = original_implementation.sign_in_up_post + + async def sign_in_up_post( + provider: Provider, + redirect_uri_info: Optional[RedirectUriInfo], + oauth_tokens: Optional[Dict[str, Any]], + session: Optional[SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + tenant_id: str, + api_options: ThirdPartyAPIOptions, + user_context: Dict[str, Any], + ) -> Union[ + SignInUpPostOkResult, + SignInUpPostNoEmailGivenByProviderResponse, + SignInUpNotAllowed, + GeneralErrorResponse, + ]: + json_body = await api_options.request.json() + if ( + json_body is not None + and json_body.get("userContext", {}).get("DO_LINK") is not None + ): + user_context["DO_LINK"] = json_body["userContext"]["DO_LINK"] + + result = await og_sign_in_up_post( + provider, + redirect_uri_info, + oauth_tokens, + session, + should_try_linking_with_session_user, + tenant_id, + api_options, + user_context, + ) + + if isinstance(result, SignInUpPostOkResult): + global user_in_callback + user_in_callback = result.user + + return result + + original_implementation.sign_in_up_post = sign_in_up_post + return original_implementation + + return tp_override_apis + + elif eval_str.startswith("accountlinking.init.shouldDoAutomaticAccountLinking"): + if "onlyLinkIfNewUserVerified" in eval_str: + + async def func4( + new_user_account: Any, + existing_user: Any, + session: Any, + tenant_id: Any, + user_context: Dict[str, Any], + ) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: + if user_context.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + + if ( + new_user_account.third_party is not None + and existing_user is not None + ): + if user_context.get("isVerified"): + return ShouldAutomaticallyLink(should_require_verification=True) + return ShouldNotAutomaticallyLink() + + return ShouldAutomaticallyLink(should_require_verification=True) + + return func4 + + async def func( + i: Any, l: Any, o: Any, u: Any, a: Any # pylint: disable=unused-argument + ) -> Union[ShouldNotAutomaticallyLink, ShouldAutomaticallyLink]: + if ( + "()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!1})" + in eval_str + ): + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,u,a)=>a.DO_LINK?{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}:{shouldAutomaticallyLink:!1}" + in eval_str + ): + if a.get("DO_LINK"): + return ShouldAutomaticallyLink(should_require_verification=True) + return ShouldNotAutomaticallyLink() + + if ( + "(i,l,o,u,a)=>a.DO_NOT_LINK?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,u,a)=>a.DO_NOT_LINK?{shouldAutomaticallyLink:!1}:a.DO_LINK_WITHOUT_VERIFICATION?{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if a.get("DO_LINK_WITHOUT_VERIFICATION"): + return ShouldAutomaticallyLink(should_require_verification=False) + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + '(i,l,o,a,e)=>e.DO_NOT_LINK||"test2@example.com"===i.email&&void 0===l?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}' + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if i.email == "test2@example.com" and l is None: + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,d,t)=>t.DO_NOT_LINK||void 0!==l&&l.id===o.getUserId()?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!1}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if l is not None and l.id == o.get_user_id(): + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=False) + + if ( + "(i,l,o,d,t)=>t.DO_NOT_LINK||void 0!==l&&l.id===o.getUserId()?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}" + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if l is not None and l.id == o.get_user_id(): + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + '(i,l,o,a,e)=>e.DO_NOT_LINK||"test2@example.com"===i.email&&void 0===l?{shouldAutomaticallyLink:!1}:{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}' + in eval_str + ): + if a.get("DO_NOT_LINK"): + return ShouldNotAutomaticallyLink() + if i.email == "test2@example.com" and l is None: + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + 'async(i,e)=>{if("emailpassword"===i.recipeId){if(!((await supertokens.listUsersByAccountInfo("public",{email:i.email})).length>1))return{shouldAutomaticallyLink:!1}}return{shouldAutomaticallyLink:!0,shouldRequireVerification:!0}}' + in eval_str + ): + if i.recipe_id == "emailpassword": + users = await list_users_by_account_info( + "public", AccountInfo(email=i.email) + ) + if len(users) <= 1: + return ShouldNotAutomaticallyLink() + return ShouldAutomaticallyLink(should_require_verification=True) + + if ( + "async()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!0})" + in eval_str + or "()=>({shouldAutomaticallyLink:!0,shouldRequireVerification:!0})" + in eval_str + ): + return ShouldAutomaticallyLink(should_require_verification=True) + + return ShouldNotAutomaticallyLink() + + return func + + if eval_str.startswith("thirdparty.init.signInAndUpFeature.providers"): + + def custom_provider(provider: Any): + if "custom-ev" in eval_str: + + async def exchange_auth_code_for_oauth_tokens1( + redirect_uri_info: RedirectUriInfo, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info.redirect_uri_query_params + + async def get_user_info1( + oauth_tokens: Any, + user_context: Any, # pylint: disable=unused-argument + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "user"), + email=UserInfoEmail( + email=oauth_tokens.get("email", "email@test.com"), + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens1 + ) + provider.get_user_info = get_user_info1 + return provider + + if "custom-no-ev" in eval_str: + + async def exchange_auth_code_for_oauth_tokens2( + redirect_uri_info: RedirectUriInfo, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info.redirect_uri_query_params + + async def get_user_info2( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "user"), + email=UserInfoEmail( + email=oauth_tokens.get("email", "email@test.com"), + is_verified=False, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens2 + ) + provider.get_user_info = get_user_info2 + return provider + + if "custom2" in eval_str: + + async def exchange_auth_code_for_oauth_tokens3( + redirect_uri_info: RedirectUriInfo, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info.redirect_uri_query_params + + async def get_user_info3( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=f"custom2{oauth_tokens['email']}", + email=UserInfoEmail( + email=oauth_tokens["email"], + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens3 + ) + provider.get_user_info = get_user_info3 + return provider + + if "custom3" in eval_str: + + async def exchange_auth_code_for_oauth_tokens4( + redirect_uri_info: RedirectUriInfo, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info.redirect_uri_query_params + + async def get_user_info4( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens["email"], + email=UserInfoEmail( + email=oauth_tokens["email"], + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens4 + ) + provider.get_user_info = get_user_info4 + return provider + + if "custom" in eval_str: + + async def exchange_auth_code_for_oauth_tokens5( + redirect_uri_info: RedirectUriInfo, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info.redirect_uri_query_params + + async def get_user_info5( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + if oauth_tokens.get("error"): + raise Exception("Credentials error") + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "userId"), + email=( + None + if oauth_tokens.get("email") is None + else UserInfoEmail( + email=oauth_tokens.get("email"), + is_verified=oauth_tokens.get("isVerified", False), + ) + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens5 + ) + provider.get_user_info = get_user_info5 + return provider + + return custom_provider + + if eval_str.startswith("emailverification.init.getEmailForRecipeUserId"): + from supertokens_python.recipe.emailverification.interfaces import ( + UnknownUserIdError as EVUnknownUserId, + ) + + async def get_email_for_recipe_user_id( + recipe_user_id: RecipeUserId, + user_context: Dict[str, Any], + ) -> Union[GetEmailForUserIdOkResult, EmailDoesNotExistError, EVUnknownUserId]: + if "random@example.com" in eval_str: + return GetEmailForUserIdOkResult(email="random@example.com") + + if ( + hasattr(recipe_user_id, "get_as_string") + and recipe_user_id.get_as_string() == "random" + ): + return GetEmailForUserIdOkResult(email="test@example.com") + + return EVUnknownUserId() + + return get_email_for_recipe_user_id + + raise Exception("Unknown eval string: " + eval_str) + + +class OverrideParams(APIResponse): + def __init__( + self, + send_email_to_user_id: Optional[str] = None, + token: Optional[str] = None, + user_post_password_reset: Optional[User] = None, + email_post_password_reset: Optional[str] = None, + send_email_callback_called: Optional[bool] = None, + send_email_to_user_email: Optional[str] = None, + send_email_inputs: Optional[List[str]] = None, + send_sms_inputs: List[str] = [], # pylint: disable=dangerous-default-value + send_email_to_recipe_user_id: Optional[str] = None, + user_in_callback: Optional[ + Union[User, VerificationEmailTemplateVarsUser] + ] = None, + email: Optional[str] = None, + new_account_info_in_callback: Optional[RecipeLevelUser] = None, + primary_user_in_callback: Optional[User] = None, + user_id_in_callback: Optional[str] = None, + recipe_user_id_in_callback: Optional[str] = None, + core_call_count: int = 0, + store: Dict[str, Any] = {}, # pylint: disable=dangerous-default-value + ): + self.send_email_to_user_id = send_email_to_user_id + self.token = token + self.user_post_password_reset = user_post_password_reset + self.email_post_password_reset = email_post_password_reset + self.send_email_callback_called = send_email_callback_called + self.send_email_to_user_email = send_email_to_user_email + self.send_email_inputs = send_email_inputs + self.send_sms_inputs = send_sms_inputs + self.send_email_to_recipe_user_id = send_email_to_recipe_user_id + self.user_in_callback = user_in_callback + self.email = email + self.new_account_info_in_callback = new_account_info_in_callback + self.primary_user_in_callback = primary_user_in_callback + self.user_id_in_callback = user_id_in_callback + self.recipe_user_id_in_callback = recipe_user_id_in_callback + self.core_call_count = core_call_count + self.store = store + + def to_json(self) -> Dict[str, Any]: + respon_json = { + "sendEmailToUserId": self.send_email_to_user_id, + "token": self.token, + "userPostPasswordReset": ( + self.user_post_password_reset.to_json() + if self.user_post_password_reset is not None + else None + ), + "emailPostPasswordReset": self.email_post_password_reset, + "sendEmailCallbackCalled": self.send_email_callback_called, + "sendEmailToUserEmail": self.send_email_to_user_email, + "sendEmailInputs": self.send_email_inputs, + "sendSmsInputs": self.send_sms_inputs, + "sendEmailToRecipeUserId": ( + # this is intentionally done this way cause the test in the test suite expects this way. + {"recipeUserId": self.send_email_to_recipe_user_id} + if self.send_email_to_recipe_user_id is not None + else None + ), + "userInCallback": ( + self.user_in_callback.to_json() + if self.user_in_callback is not None + else None + ), + "email": self.email, + "newAccountInfoInCallback": ( + self.new_account_info_in_callback.to_json() + if self.new_account_info_in_callback is not None + else None + ), + "primaryUserInCallback": ( + self.primary_user_in_callback.to_json() + if self.primary_user_in_callback is not None + else None + ), + "userIdInCallback": self.user_id_in_callback, + "recipeUserIdInCallback": self.recipe_user_id_in_callback, + "info": { + "coreCallCount": self.core_call_count, + }, + "store": self.store, + } + # Filter out items that are None + respon_json = {k: v for k, v in respon_json.items() if v is not None} + return respon_json + + +def get_override_params() -> OverrideParams: + return OverrideParams( + send_email_to_user_id=send_email_to_user_id, + token=token, + user_post_password_reset=user_post_password_reset, + email_post_password_reset=email_post_password_reset, + send_email_callback_called=send_email_callback_called, + send_email_to_user_email=send_email_to_user_email, + send_email_inputs=send_email_inputs, + send_sms_inputs=send_sms_inputs, + send_email_to_recipe_user_id=send_email_to_recipe_user_id, + user_in_callback=user_in_callback, + email=email_param, + new_account_info_in_callback=new_account_info_in_callback, + primary_user_in_callback=( + primary_user_in_callback if primary_user_in_callback else None + ), + user_id_in_callback=user_id_in_callback, + recipe_user_id_in_callback=( + recipe_user_id_in_callback.get_as_string() + if isinstance(recipe_user_id_in_callback, RecipeUserId) + else None + ), + core_call_count=Info.core_call_count, + store=store, + ) + + +def reset_override_params(): + global send_email_to_user_id, token, user_post_password_reset, email_post_password_reset, send_email_callback_called, send_email_to_user_email, send_email_inputs, send_sms_inputs, send_email_to_recipe_user_id, user_in_callback, email_param, primary_user_in_callback, new_account_info_in_callback, user_id_in_callback, recipe_user_id_in_callback, store + send_email_to_user_id = None + token = None + user_post_password_reset = None + email_post_password_reset = None + send_email_callback_called = False + send_email_to_user_email = None + send_email_inputs = [] + send_sms_inputs = [] + send_email_to_recipe_user_id = None + user_in_callback = None + email_param = None + primary_user_in_callback = None + new_account_info_in_callback = None + user_id_in_callback = None + recipe_user_id_in_callback = None + store = {} + Info.core_call_count = 0 + + +send_email_to_user_id: Optional[str] = None +token: Optional[str] = None +user_post_password_reset: Optional[User] = None +email_post_password_reset: Optional[str] = None +send_email_callback_called: bool = False +send_email_to_user_email: Optional[str] = None +send_email_inputs: List[Any] = [] +send_sms_inputs: List[Any] = [] +send_email_to_recipe_user_id: Optional[str] = None +user_in_callback: Optional[Union[User, VerificationEmailTemplateVarsUser]] = None +email_param: Optional[str] = None +primary_user_in_callback: Optional[User] = None +new_account_info_in_callback: Optional[RecipeLevelUser] = None +user_id_in_callback: Optional[str] = None +recipe_user_id_in_callback: Optional[RecipeUserId] = None +store: Dict[str, Any] = {} diff --git a/tests/test-server/thirdparty.py b/tests/test-server/thirdparty.py new file mode 100644 index 000000000..513fd5a6b --- /dev/null +++ b/tests/test-server/thirdparty.py @@ -0,0 +1,90 @@ +from flask import Flask, request, jsonify + +from session import convert_session_to_container # pylint: disable=import-error +from supertokens_python.recipe.thirdparty.interfaces import ( + EmailChangeNotAllowedError, + ManuallyCreateOrUpdateUserOkResult, + SignInUpNotAllowed, +) +from supertokens_python.recipe.thirdparty.syncio import manually_create_or_update_user +from utils import ( # pylint: disable=import-error + serialize_user, + serialize_recipe_user_id, +) # pylint: disable=import-error + + +def add_thirdparty_routes(app: Flask): + @app.route("/test/thirdparty/manuallycreateorupdateuser", methods=["POST"]) # type: ignore + def thirdpartymanuallycreateorupdate(): # type: ignore + data = request.json + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + tenant_id = data.get("tenantId", "public") + third_party_id = data["thirdPartyId"] + third_party_user_id = data["thirdPartyUserId"] + email = data["email"] + is_verified = data["isVerified"] + user_context = data.get("userContext", {}) + + session = None + if data.get("session"): + session = convert_session_to_container(data["session"]) + + response = manually_create_or_update_user( + tenant_id, + third_party_id, + third_party_user_id, + email, + is_verified, + session, + user_context, + ) + + if isinstance(response, ManuallyCreateOrUpdateUserOkResult): + return jsonify( + { + "status": "OK", + **serialize_user( + response.user, request.headers.get("fdi-version", "") + ), + **serialize_recipe_user_id( + response.recipe_user_id, request.headers.get("fdi-version", "") + ), + } + ) + elif isinstance(response, EmailChangeNotAllowedError): + return jsonify( + {"status": "EMAIL_CHANGE_NOT_ALLOWED_ERROR", "reason": response.reason} + ) + elif isinstance(response, SignInUpNotAllowed): + return jsonify(response.to_json()) + elif isinstance(response, SignInUpNotAllowed): + return jsonify(response.to_json()) + else: + return jsonify( + { + "status": response.status, + "reason": response.reason, + } + ) + + @app.route("/test/thirdparty/getprovider", methods=["POST"]) # type: ignore + def get_provider(): # type: ignore + data = request.get_json() + if data is None: + return jsonify({"status": "MISSING_DATA_ERROR"}) + + tenant_id = data.get("tenantId", "public") + third_party_id = data["thirdPartyId"] + client_type = data.get("clientType", None) + user_context = data.get("userContext", {}) + + from supertokens_python.recipe.thirdparty.syncio import get_provider + + provider = get_provider(tenant_id, third_party_id, client_type, user_context) + + if provider is None: + return jsonify({}) + + return jsonify({"id": provider.id, "config": provider.config.to_json()}) diff --git a/tests/test-server/totp.py b/tests/test-server/totp.py new file mode 100644 index 000000000..46e8af537 --- /dev/null +++ b/tests/test-server/totp.py @@ -0,0 +1,52 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.totp.syncio import create_device, verify_device +from supertokens_python.recipe.totp.types import ( + CreateDeviceOkResult, +) +from supertokens_python.recipe.totp.types import ( + DeviceAlreadyExistsError, + InvalidTOTPError, + UnknownDeviceError, + VerifyDeviceOkResult, +) + + +def add_totp_routes(app: Flask): + @app.route("/test/totp/createdevice", methods=["POST"]) # type: ignore + def create_device_api(): # type: ignore + assert request.json is not None + body = request.json + response = create_device( + user_id=body.get("userId"), + user_identifier_info=body.get("userIdentifierInfo"), + device_name=body.get("deviceName"), + skew=body.get("skew"), + period=body.get("period"), + user_context=body.get("userContext"), + ) + if isinstance(response, CreateDeviceOkResult): + return jsonify(response.to_json()) + elif isinstance(response, DeviceAlreadyExistsError): + return jsonify(response.to_json()) + else: + return jsonify({"status": "UNKNOWN_USER_ID_ERROR"}) + + @app.route("/test/totp/verifydevice", methods=["POST"]) # type: ignore + def verify_device_api(): # type: ignore + assert request.json is not None + body = request.json + response = verify_device( + tenant_id=body.get("tenantId"), + user_id=body.get("userId"), + device_name=body.get("deviceName"), + totp=body.get("totp"), + user_context=body.get("userContext"), + ) + if isinstance(response, VerifyDeviceOkResult): + return jsonify(response.to_json()) + elif isinstance(response, UnknownDeviceError): + return jsonify(response.to_json()) + elif isinstance(response, InvalidTOTPError): + return jsonify(response.to_json()) + else: + return jsonify(response.to_json()) diff --git a/tests/test-server/usermetadata.py b/tests/test-server/usermetadata.py new file mode 100644 index 000000000..a4a960774 --- /dev/null +++ b/tests/test-server/usermetadata.py @@ -0,0 +1,40 @@ +from flask import Flask, request, jsonify + +from supertokens_python.recipe.usermetadata.syncio import ( + get_user_metadata, + update_user_metadata, + clear_user_metadata, +) + + +def add_usermetadata_routes(app: Flask): + @app.route("/test/usermetadata/getusermetadata", methods=["POST"]) # type: ignore + def get_user_metadata_api(): # type: ignore + assert request.json is not None + user_id = request.json["userId"] + response = get_user_metadata( + user_id=user_id, user_context=request.json.get("userContext") + ) + return jsonify({"metadata": response.metadata}) + + @app.route("/test/usermetadata/updateusermetadata", methods=["POST"]) # type: ignore + def update_user_metadata_api(): # type: ignore + assert request.json is not None + user_id = request.json["userId"] + metadata_update = request.json["metadataUpdate"] + + response = update_user_metadata( + user_id=user_id, + metadata_update=metadata_update, + user_context=request.json.get("userContext"), + ) + return jsonify({"metadata": response.metadata}) + + @app.route("/test/usermetadata/clearusermetadata", methods=["POST"]) # type: ignore + def clear_user_metadata_api(): # type: ignore + assert request.json is not None + user_id = request.json["userId"] + clear_user_metadata( + user_id=user_id, user_context=request.json.get("userContext") + ) + return jsonify({"status": "OK"}) diff --git a/tests/test-server/utils.py b/tests/test-server/utils.py index 3ccc84a4b..1d70e8035 100644 --- a/tests/test-server/utils.py +++ b/tests/test-server/utils.py @@ -1,25 +1,85 @@ -from typing import Any, Dict +from supertokens_python.recipe.multifactorauth.multi_factor_auth_claim import ( + MultiFactorAuthClaim, +) from supertokens_python.recipe.session.claims import SessionClaim from supertokens_python.recipe.session.interfaces import SessionClaimValidator +from supertokens_python.types import RecipeUserId, User +from override_logging import log_override_event # pylint: disable=import-error +from supertokens_python.recipe.session.claims import BooleanClaim +from supertokens_python.recipe.emailverification import EmailVerificationClaim +from supertokens_python.recipe.userroles import UserRoleClaim +from supertokens_python.recipe.userroles import PermissionClaim +from typing import Any, Dict +from supertokens_python.recipe.session.claims import PrimitiveClaim + + +def mock_claim_builder(key: str, values: Any) -> PrimitiveClaim[Any]: + def fetch_value( + user_id: str, + recipe_user_id: RecipeUserId, + tenant_id: str, + current_payload: Dict[str, Any], + user_context: Dict[str, Any], + ) -> Any: + log_override_event( + f"claim-{key}.fetchValue", + "CALL", + { + "userId": user_id, + "recipeUserId": recipe_user_id.get_as_string(), + "tenantId": tenant_id, + "currentPayload": current_payload, + "userContext": user_context, + }, + ) + + ret_val: Any = user_context.get("st-stub-arr-value") or ( + values[0] + if isinstance(values, list) and isinstance(values[0], list) + else values + ) + log_override_event(f"claim-{key}.fetchValue", "RES", ret_val) -test_claims: Dict[str, SessionClaim] = {} # type: ignore + return ret_val + return PrimitiveClaim(key=key or "st-stub-primitive", fetch_value=fetch_value) -def init_test_claims(): - add_builtin_claims() +test_claim_setups: Dict[str, SessionClaim[Any]] = { + "st-true": BooleanClaim( + key="st-true", + fetch_value=lambda *_args, **_kwargs: True, # type: ignore + ), + "st-undef": BooleanClaim( + key="st-undef", + fetch_value=lambda *_args, **_kwargs: None, # type: ignore + ), +} -def add_builtin_claims(): - from supertokens_python.recipe.emailverification import EmailVerificationClaim +# Add all built-in claims +for claim in [ + EmailVerificationClaim, + MultiFactorAuthClaim, + UserRoleClaim, + PermissionClaim, +]: + test_claim_setups[claim.key] = claim # type: ignore - test_claims[EmailVerificationClaim.key] = EmailVerificationClaim + +def deserialize_claim(serialized_claim: Dict[str, Any]) -> SessionClaim[Any]: + key = serialized_claim["key"] + + if key.startswith("st-stub-"): + return mock_claim_builder(key.replace("st-stub-", "", 1), serialized_claim) + + return test_claim_setups[key] def deserialize_validator(validatorsInput: Any) -> SessionClaimValidator: # type: ignore key = validatorsInput["key"] - if key in test_claims: - claim = test_claims[key] # type: ignore + if key in test_claim_setups: + claim = test_claim_setups[key] validator_name = validatorsInput["validatorName"] if hasattr(claim.validators, toSnakeCase(validator_name)): # type: ignore validator_func = getattr(claim.validators, toSnakeCase(validator_name)) # type: ignore @@ -39,3 +99,49 @@ def toSnakeCase(camel_case: str) -> str: else: result += char return result + + +def get_max_version(v1: str, v2: str) -> str: + v1_split = v1.split(".") + v2_split = v2.split(".") + max_loop = min(len(v1_split), len(v2_split)) + + for i in range(max_loop): + if int(v1_split[i]) > int(v2_split[i]): + return v1 + if int(v2_split[i]) > int(v1_split[i]): + return v2 + + if len(v1_split) > len(v2_split): + return v1 + + return v2 + + +def serialize_user(user: User, fdi_version: str) -> Dict[str, Any]: + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + return { + "user": { + "id": user.id, + "email": user.emails[0], + "timeJoined": user.time_joined, + "tenantIds": user.tenant_ids, + } + } + else: + return {"user": user.to_json()} + + +def serialize_recipe_user_id( + recipe_user_id: RecipeUserId, fdi_version: str +) -> Dict[str, Any]: + if get_max_version("1.17", fdi_version) == "1.17" or ( + get_max_version("2.0", fdi_version) == fdi_version + and get_max_version("3.0", fdi_version) != fdi_version + ): + return {} + else: + return {"recipeUserId": recipe_user_id.get_as_string()} diff --git a/tests/test_config.py b/tests/test_config.py index 1a8b30963..de831e405 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -21,6 +21,7 @@ from supertokens_python.recipe.session.asyncio import create_new_session from typing import Optional, Dict, Any from supertokens_python.framework import BaseRequest +from supertokens_python.types import RecipeUserId from tests.utils import clean_st, reset, setup_st, start_st @@ -698,7 +699,7 @@ async def test_samesite_valid_config(): ["http://localhost:3000", "https://supertokensapi.io"], ["http://127.0.0.1:3000", "https://supertokensapi.io"], ] - for (website_domain, api_domain) in domain_combinations: + for website_domain, api_domain in domain_combinations: reset() clean_st() setup_st() @@ -717,6 +718,10 @@ async def test_samesite_valid_config(): @mark.asyncio async def test_samesite_invalid_config(): + reset() + clean_st() + setup_st() + start_st() domain_combinations = [ ["http://localhost:3000", "http://supertokensapi.io"], ["http://127.0.0.1:3000", "http://supertokensapi.io"], @@ -724,10 +729,8 @@ async def test_samesite_invalid_config(): ["http://supertokens.io", "http://127.0.0.1:8000"], ["http://supertokens.io", "http://supertokensapi.io"], ] - for (website_domain, api_domain) in domain_combinations: - reset() - clean_st() - setup_st() + for website_domain, api_domain in domain_combinations: + reset(False) try: init( supertokens_config=SupertokensConfig("http://localhost:3567"), @@ -744,7 +747,9 @@ async def test_samesite_invalid_config(): ) ], ) - await create_new_session("public", MagicMock(), "userId", {}, {}) + await create_new_session( + MagicMock(), "public", RecipeUserId("userId"), {}, {} + ) except Exception as e: assert ( str(e) diff --git a/tests/test_config_without_core.py b/tests/test_config_without_core.py index 5cbe7e28c..4d3b2a7b4 100644 --- a/tests/test_config_without_core.py +++ b/tests/test_config_without_core.py @@ -1,9 +1,9 @@ from pytest import mark -from supertokens_python import InputAppInfo, Supertokens, SupertokensConfig, init +from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.recipe import session from supertokens_python.recipe.session import SessionRecipe -from supertokens_python.recipe.multitenancy.recipe import MultitenancyRecipe +from tests.utils import reset @mark.parametrize( @@ -31,9 +31,7 @@ def test_same_site_cookie_values( api_domain: str, website_domain: str, cookie_same_site: str ): - Supertokens.reset() - SessionRecipe.reset() - MultitenancyRecipe.reset() + reset() init( supertokens_config=SupertokensConfig("http://localhost:3567"), @@ -53,6 +51,3 @@ def test_same_site_cookie_values( s = SessionRecipe.get_instance() assert s.config.get_cookie_same_site(None, {}) == cookie_same_site - SessionRecipe.reset() - MultitenancyRecipe.reset() - Supertokens.reset() diff --git a/tests/test_middleware.py b/tests/test_middleware.py index c5c5b8003..b7911ee4c 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -190,7 +190,8 @@ async def test_wrong_rid_with_existing_api_works( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"][0] == user_info["emails"][0] + assert len(dict_response["user"]["emails"]) == 1 @mark.asyncio @@ -239,7 +240,8 @@ async def test_random_rid_with_existing_api_works( assert response_2.status_code == 200 dict_response = json.loads(response_2.text) assert dict_response["user"]["id"] == user_info["id"] - assert dict_response["user"]["email"] == user_info["email"] + assert dict_response["user"]["emails"][0] == user_info["emails"][0] + assert len(dict_response["user"]["emails"]) == 1 @mark.asyncio diff --git a/tests/test_pagination.py b/tests/test_pagination.py index d66b9163b..fb1249e90 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -51,16 +51,16 @@ async def test_get_users_pagination(): # Get all the users (No limit) response = await get_users_newest_first("public") - assert [user.email for user in response.users] == [ + assert [user.emails[0] for user in response.users] == [ f"dummy{i}@gmail.com" for i in range(5) ][::-1] # Get only the oldest user response = await get_users_oldest_first("public", limit=1) - assert [user.email for user in response.users] == ["dummy0@gmail.com"] + assert [user.emails[0] for user in response.users] == ["dummy0@gmail.com"] # Test pagination response = await get_users_oldest_first( "public", limit=1, pagination_token=response.next_pagination_token ) - assert [user.email for user in response.users] == ["dummy1@gmail.com"] + assert [user.emails[0] for user in response.users] == ["dummy1@gmail.com"] diff --git a/tests/test_passwordless.py b/tests/test_passwordless.py index 309dffdcb..db271ab79 100644 --- a/tests/test_passwordless.py +++ b/tests/test_passwordless.py @@ -15,6 +15,12 @@ from typing import Any, Dict from fastapi import FastAPI +from supertokens_python.asyncio import get_user, list_users_by_account_info +from supertokens_python.recipe.passwordless.asyncio import ( + delete_email_for_user, + delete_phone_number_for_user, +) +from supertokens_python.types import AccountInfo, RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark, raises, skip from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -22,16 +28,10 @@ from supertokens_python.querier import Querier from supertokens_python.recipe import passwordless, session from supertokens_python.recipe.passwordless.asyncio import ( - delete_email_for_user, - delete_phone_number_for_user, - get_user_by_email, - get_user_by_id, - get_user_by_phone_number, update_user, create_magic_link, ) from supertokens_python.recipe.passwordless.interfaces import ( - DeleteUserInfoOkResult, UpdateUserOkResult, ) from supertokens_python.utils import is_version_gte @@ -115,12 +115,30 @@ async def send_sms( ).json() consume_code_json["user"].pop("id") - consume_code_json["user"].pop("time_joined") + consume_code_json["user"].pop("timeJoined") + consume_code_json["user"]["loginMethods"][0].pop("recipeUserId") + consume_code_json["user"]["loginMethods"][0].pop("timeJoined") assert consume_code_json == { "status": "OK", - "createdNewUser": True, - "user": {"phoneNumber": "+919494949494"}, + "createdNewRecipeUser": True, + "user": { + "isPrimaryUser": False, + "tenantIds": ["public"], + "emails": [], + "phoneNumbers": ["+919494949494"], + "thirdParty": [], + "loginMethods": [ + { + "recipeId": "passwordless", + "tenantIds": ["public"], + "email": None, + "phoneNumber": "+919494949494", + "thirdParty": None, + "verified": True, + } + ], + }, } @@ -225,16 +243,18 @@ async def send_sms( user_id = consume_code_json["user"]["id"] - await update_user(user_id, "foo@example.com", "+919494949494") + await update_user(RecipeUserId(user_id), "foo@example.com", "+919494949494") - response = await delete_phone_number_for_user(user_id) - assert isinstance(response, DeleteUserInfoOkResult) + response = await delete_phone_number_for_user(RecipeUserId(user_id)) + assert isinstance(response, UpdateUserOkResult) - user = await get_user_by_phone_number("public", "+919494949494") - assert user is None + user = await list_users_by_account_info( + "public", AccountInfo(phone_number="+919494949494") + ) + assert len(user) == 0 - user = await get_user_by_id(user_id) - assert user is not None and user.phone_number is None + user = await get_user(user_id) + assert user is not None and user.phone_numbers == [] @mark.asyncio @@ -308,16 +328,18 @@ async def send_sms( user_id = consume_code_json["user"]["id"] - await update_user(user_id, "hello@example.com", "+919494949494") + await update_user(RecipeUserId(user_id), "hello@example.com", "+919494949494") - response = await delete_email_for_user(user_id) - assert isinstance(response, DeleteUserInfoOkResult) + response = await delete_email_for_user(RecipeUserId(user_id)) + assert isinstance(response, UpdateUserOkResult) - user = await get_user_by_email("public", "hello@example.com") - assert user is None + user = await list_users_by_account_info( + "public", AccountInfo(email="hello@example.com") + ) + assert len(user) == 0 - user = await get_user_by_id(user_id) - assert user is not None and user.email is None + user = await get_user(user_id) + assert user is not None and user.emails == [] @mark.asyncio @@ -393,14 +415,16 @@ async def send_sms( user_id = consume_code_json["user"]["id"] - response = await update_user(user_id, "hello@example.com", "+919494949494") + response = await update_user( + RecipeUserId(user_id), "hello@example.com", "+919494949494" + ) assert isinstance(response, UpdateUserOkResult) # Delete the email - response = await delete_email_for_user(user_id) + response = await delete_email_for_user(RecipeUserId(user_id)) # Delete the phone number (Should raise exception because deleting both of them isn't allowed) with raises(Exception) as e: - response = await delete_phone_number_for_user(user_id) + response = await delete_phone_number_for_user(RecipeUserId(user_id)) assert e.value.args[0].endswith( "You cannot clear both email and phone number of a user\n" diff --git a/tests/test_querier.py b/tests/test_querier.py index 2b4746a14..0af1d19e7 100644 --- a/tests/test_querier.py +++ b/tests/test_querier.py @@ -20,10 +20,7 @@ thirdparty, ) from supertokens_python import InputAppInfo -from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id, sign_up -from supertokens_python.recipe.thirdparty.asyncio import ( - get_user_by_id as tp_get_user_by_id, -) +from supertokens_python.recipe.emailpassword.asyncio import get_user, sign_up import asyncio import respx import httpx @@ -236,29 +233,29 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await tp_get_user_by_id("random", user_context) + user = await get_user("random2", user_context) assert user is None assert called_core called_core = False - user = await tp_get_user_by_id("random", user_context) + user = await get_user("random2", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core @@ -299,22 +296,22 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core @@ -357,14 +354,14 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core @@ -407,20 +404,20 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core called_core = False - user = await tp_get_user_by_id("random", user_context) + user = await get_user("random2", user_context) assert user is None assert called_core @@ -461,7 +458,7 @@ def intercept( ) # type: ignore start_st() user_context: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core @@ -470,7 +467,7 @@ def intercept( called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core @@ -513,33 +510,33 @@ def intercept( user_context: Dict[str, Any] = {"_default": {"keep_cache_alive": True}} user_context_2: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert not called_core @@ -583,33 +580,33 @@ def intercept( user_context: Dict[str, Any] = {"_default": {"keep_cache_alive": False}} user_context_2: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core @@ -653,33 +650,33 @@ def intercept( user_context: Dict[str, Any] = {} user_context_2: Dict[str, Any] = {} - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core - await sign_up("public", "test@example.com", "abcd1234", user_context) + await sign_up("public", "test@example.com", "abcd1234", None, user_context) called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert called_core called_core = False - user = await get_user_by_id("random", user_context) + user = await get_user("random", user_context) assert user is None assert not called_core - user = await get_user_by_id("random", user_context_2) + user = await get_user("random", user_context_2) assert user is None assert called_core diff --git a/tests/test_session.py b/tests/test_session.py index 38cc901da..1f88a1417 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -21,13 +21,14 @@ from fastapi.requests import Request from fastapi.responses import JSONResponse from pytest import fixture, mark +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from requests.cookies import cookiejar_from_dict # type: ignore from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.framework import BaseRequest from supertokens_python.framework.fastapi.fastapi_middleware import get_middleware -from supertokens_python.process_state import AllowedProcessStates, ProcessState +from supertokens_python.process_state import PROCESS_STATE, ProcessState from supertokens_python.recipe import session from supertokens_python.recipe.session import InputOverrideConfig, SessionRecipe from supertokens_python.recipe.session.asyncio import ( @@ -105,7 +106,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", "", False, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(""), False, {}, {}, None ) assert response.session is not None @@ -119,7 +120,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): s.recipe_implementation, access_token, response.antiCsrfToken, True, False, None ) assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -151,8 +152,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): ) assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY - in ProcessState.get_instance().history + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY in ProcessState.get_instance().history ) assert response3.session is not None @@ -173,7 +173,7 @@ async def test_that_once_the_info_is_loaded_it_doesnt_query_again(): None, ) assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -211,7 +211,7 @@ async def test_creating_many_sessions_for_one_user_and_looping(): new_session = await create_new_session( s.recipe_implementation, "public", - "someUser", + RecipeUserId("someUser"), False, {"someKey": "someValue"}, {}, @@ -219,7 +219,7 @@ async def test_creating_many_sessions_for_one_user_and_looping(): ) access_tokens.append(new_session.accessToken.token) - session_handles = await get_all_session_handles_for_user("someUser", "public") + session_handles = await get_all_session_handles_for_user("someUser", True, "public") assert len(session_handles) == 7 @@ -281,7 +281,9 @@ async def home(_request: Request): # type: ignore @app.post("/create") async def create_api(request: Request): # type: ignore - await async_create_new_session(request, "public", "test-user", {}, {}) + await async_create_new_session( + request, "public", RecipeUserId("test-user"), {}, {} + ) return "" @app.post("/sessioninfo-optional") @@ -314,7 +316,7 @@ async def test_signout_api_works_even_if_session_is_deleted_after_creation( user_id = "user_id" response = await create_new_session( - s.recipe_implementation, "public", user_id, False, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), False, {}, {}, None ) session_handle = response.session.handle @@ -401,7 +403,9 @@ async def get_session_information( mock_response = MagicMock() - my_session = await async_create_new_session(mock_response, "public", "test_id") + my_session = await async_create_new_session( + mock_response, "public", RecipeUserId("test_id"), {} + ) data = await my_session.get_session_data_from_database() assert data == {"foo": "bar"} @@ -734,7 +738,9 @@ async def test_that_verify_session_doesnt_always_call_core(): # response = await create_new_session(s.recipe_implementation, "", False, {}, {}) - session1 = await create_new_session_without_request_response("public", "user-id") + session1 = await create_new_session_without_request_response( + "public", RecipeUserId("user-id") + ) assert session1 is not None assert session1.access_token != "" @@ -742,7 +748,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session1.refresh_token is not None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -756,7 +762,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session2.refresh_token is None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -770,7 +776,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session3.refresh_token is not None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY not in ProcessState.get_instance().history ) @@ -784,8 +790,7 @@ async def test_that_verify_session_doesnt_always_call_core(): assert session4.refresh_token is None assert ( - AllowedProcessStates.CALLING_SERVICE_IN_VERIFY - in ProcessState.get_instance().history + PROCESS_STATE.CALLING_SERVICE_IN_VERIFY in ProcessState.get_instance().history ) # Core got called this time diff --git a/tests/test_user_context.py b/tests/test_user_context.py index 5b51fd08d..039843f0d 100644 --- a/tests/test_user_context.py +++ b/tests/test_user_context.py @@ -11,10 +11,11 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from pathlib import Path from fastapi import FastAPI +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark @@ -73,12 +74,19 @@ def apis_override_email_password(param: APIInterface): async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ): user_context = {"preSignInPOST": True} response = await og_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) if ( "preSignInPOST" in user_context @@ -99,20 +107,44 @@ def functions_override_email_password(param: RecipeInterface): og_sign_up = param.sign_up async def sign_up_( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], ): if "manualCall" in user_context: global signUpContextWorks signUpContextWorks = True - response = await og_sign_up(email, password, tenant_id, user_context) + response = await og_sign_up( + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, + ) return response async def sign_in( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], ): if "preSignInPOST" in user_context: user_context["preSignIn"] = True - response = await og_sign_in(email, password, tenant_id, user_context) + response = await og_sign_in( + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, + ) if "preSignInPOST" in user_context and "preSignIn" in user_context: user_context["postSignIn"] = True return response @@ -126,6 +158,7 @@ def functions_override_session(param: SRecipeInterface): async def create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -140,6 +173,7 @@ async def create_new_session( user_context["preCreateNewSession"] = True response = await og_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -182,7 +216,9 @@ async def create_new_session( ) start_st() - await sign_up("public", "random@gmail.com", "validpass123", {"manualCall": True}) + await sign_up( + "public", "random@gmail.com", "validpass123", None, {"manualCall": True} + ) res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") assert res.status_code == 200 @@ -204,6 +240,8 @@ def apis_override_email_password(param: APIInterface): async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ): @@ -213,7 +251,12 @@ async def sign_in_post( signin_api_context_works = True return await og_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) param.sign_in_post = sign_in_post @@ -223,14 +266,26 @@ def functions_override_email_password(param: RecipeInterface): og_sign_in = param.sign_in async def sign_in( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], ): req = user_context.get("_default", {}).get("request") if req: nonlocal signin_context_works signin_context_works = True - return await og_sign_in(email, password, tenant_id, user_context) + return await og_sign_in( + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, + ) param.sign_in = sign_in return param @@ -240,6 +295,7 @@ def functions_override_session(param: SRecipeInterface): async def create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -253,6 +309,7 @@ async def create_new_session( response = await og_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -288,7 +345,9 @@ async def create_new_session( ) start_st() - await sign_up("public", "random@gmail.com", "validpass123", {"manualCall": True}) + await sign_up( + "public", "random@gmail.com", "validpass123", None, {"manualCall": True} + ) res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") assert res.status_code == 200 @@ -315,6 +374,8 @@ def apis_override_email_password(param: APIInterface): async def sign_in_post( form_fields: List[FormField], tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], api_options: APIOptions, user_context: Dict[str, Any], ): @@ -326,7 +387,12 @@ async def sign_in_post( signin_api_context_works = True return await og_sign_in_post( - form_fields, tenant_id, api_options, user_context + form_fields, + tenant_id, + session, + should_try_linking_with_session_user, + api_options, + user_context, ) param.sign_in_post = sign_in_post @@ -336,7 +402,12 @@ def functions_override_email_password(param: RecipeInterface): og_sign_in = param.sign_in async def sign_in( - email: str, password: str, tenant_id: str, user_context: Dict[str, Any] + email: str, + password: str, + tenant_id: str, + session: Optional[session.SessionContainer], + should_try_linking_with_session_user: Union[bool, None], + user_context: Dict[str, Any], ): req = get_request_from_user_context(user_context) if req: @@ -353,7 +424,14 @@ async def sign_in( user_context["_default"]["request"] = orginal_request - return await og_sign_in(email, password, tenant_id, user_context) + return await og_sign_in( + email, + password, + tenant_id, + session, + should_try_linking_with_session_user, + user_context, + ) param.sign_in = sign_in return param @@ -363,6 +441,7 @@ def functions_override_session(param: SRecipeInterface): async def create_new_session( user_id: str, + recipe_user_id: RecipeUserId, access_token_payload: Optional[Dict[str, Any]], session_data_in_database: Optional[Dict[str, Any]], disable_anti_csrf: Optional[bool], @@ -378,6 +457,7 @@ async def create_new_session( response = await og_create_new_session( user_id, + recipe_user_id, access_token_payload, session_data_in_database, disable_anti_csrf, @@ -413,7 +493,9 @@ async def create_new_session( ) start_st() - await sign_up("public", "random@gmail.com", "validpass123", {"manualCall": True}) + await sign_up( + "public", "random@gmail.com", "validpass123", None, {"manualCall": True} + ) res = sign_in_request(driver_config_client, "random@gmail.com", "validpass123") assert res.status_code == 200 diff --git a/tests/thirdparty/test_emaildelivery.py b/tests/thirdparty/test_emaildelivery.py index 0ca617ac1..207626613 100644 --- a/tests/thirdparty/test_emaildelivery.py +++ b/tests/thirdparty/test_emaildelivery.py @@ -19,6 +19,7 @@ import respx from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from pytest import fixture, mark @@ -139,16 +140,16 @@ async def test_email_verify_default_backward_compatibility( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -213,16 +214,16 @@ async def test_email_verify_default_backward_compatibility_supress_error( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -303,16 +304,16 @@ async def send_email( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( @@ -381,17 +382,17 @@ async def send_email( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id assert isinstance(user_id, str) response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -521,17 +522,17 @@ async def send_email_override( start_st() resp = await manually_create_or_update_user( - "public", "supertokens", "test-user-id", "test@example.com" + "public", "supertokens", "test-user-id", "test@example.com", False, None ) s = SessionRecipe.get_instance() if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") assert isinstance(resp, ManuallyCreateOrUpdateUserOkResult) - user_id = resp.user.user_id + user_id = resp.user.id assert isinstance(user_id, str) response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( diff --git a/tests/thirdparty/test_multitenancy.py b/tests/thirdparty/test_multitenancy.py index 8c205f04d..33666c845 100644 --- a/tests/thirdparty/test_multitenancy.py +++ b/tests/thirdparty/test_multitenancy.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. from pytest import mark +from supertokens_python.asyncio import get_user from supertokens_python.recipe import session, multitenancy, thirdparty from supertokens_python import init from supertokens_python.recipe.multitenancy.asyncio import ( @@ -20,12 +21,17 @@ ) from supertokens_python.recipe.thirdparty.asyncio import ( manually_create_or_update_user, - get_user_by_id, - get_users_by_email, - get_user_by_third_party_info, get_provider, ) -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) +from supertokens_python.recipe.thirdparty.interfaces import ( + ManuallyCreateOrUpdateUserOkResult, +) +from supertokens_python.asyncio import list_users_by_account_info +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import AccountInfo from tests.utils import get_st_init_args from tests.utils import ( @@ -49,29 +55,41 @@ async def test_thirtyparty_multitenancy_functions(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) # sign up: user1a = await manually_create_or_update_user( - "t1", "google", "googleid1", "test@example.com" + "t1", "google", "googleid1", "test@example.com", True, None ) + assert isinstance(user1a, ManuallyCreateOrUpdateUserOkResult) user1b = await manually_create_or_update_user( - "t1", "facebook", "fbid1", "test@example.com" + "t1", "facebook", "fbid1", "test@example.com", True, None ) + assert isinstance(user1b, ManuallyCreateOrUpdateUserOkResult) user2a = await manually_create_or_update_user( - "t2", "google", "googleid1", "test@example.com" + "t2", "google", "googleid1", "test@example.com", True, None ) + assert isinstance(user2a, ManuallyCreateOrUpdateUserOkResult) user2b = await manually_create_or_update_user( - "t2", "facebook", "fbid1", "test@example.com" + "t2", "facebook", "fbid1", "test@example.com", True, None ) + assert isinstance(user2b, ManuallyCreateOrUpdateUserOkResult) user3a = await manually_create_or_update_user( - "t3", "google", "googleid1", "test@example.com" + "t3", "google", "googleid1", "test@example.com", True, None ) + assert isinstance(user3a, ManuallyCreateOrUpdateUserOkResult) user3b = await manually_create_or_update_user( - "t3", "facebook", "fbid1", "test@example.com" + "t3", "facebook", "fbid1", "test@example.com", True, None ) + assert isinstance(user3b, ManuallyCreateOrUpdateUserOkResult) assert user1a.user.tenant_ids == ["t1"] assert user1b.user.tenant_ids == ["t1"] @@ -81,12 +99,12 @@ async def test_thirtyparty_multitenancy_functions(): assert user3b.user.tenant_ids == ["t3"] # get user by id: - g_user1a = await get_user_by_id(user1a.user.user_id) - g_user1b = await get_user_by_id(user1b.user.user_id) - g_user2a = await get_user_by_id(user2a.user.user_id) - g_user2b = await get_user_by_id(user2b.user.user_id) - g_user3a = await get_user_by_id(user3a.user.user_id) - g_user3b = await get_user_by_id(user3b.user.user_id) + g_user1a = await get_user(user1a.user.id) + g_user1b = await get_user(user1b.user.id) + g_user2a = await get_user(user2a.user.id) + g_user2b = await get_user(user2b.user.id) + g_user3a = await get_user(user3a.user.id) + g_user3b = await get_user(user3b.user.id) assert g_user1a == user1a.user assert g_user1b == user1b.user @@ -96,28 +114,76 @@ async def test_thirtyparty_multitenancy_functions(): assert g_user3b == user3b.user # get user by email: - by_email_user1 = await get_users_by_email("t1", "test@example.com") - by_email_user2 = await get_users_by_email("t2", "test@example.com") - by_email_user3 = await get_users_by_email("t3", "test@example.com") + by_email_user1 = await list_users_by_account_info( + "t1", AccountInfo(email="test@example.com") + ) + by_email_user2 = await list_users_by_account_info( + "t2", AccountInfo(email="test@example.com") + ) + by_email_user3 = await list_users_by_account_info( + "t3", AccountInfo(email="test@example.com") + ) assert by_email_user1 == [user1a.user, user1b.user] assert by_email_user2 == [user2a.user, user2b.user] assert by_email_user3 == [user3a.user, user3b.user] # get user by thirdparty id: - g_user_by_tpid1a = await get_user_by_third_party_info("t1", "google", "googleid1") - g_user_by_tpid1b = await get_user_by_third_party_info("t1", "facebook", "fbid1") - g_user_by_tpid2a = await get_user_by_third_party_info("t2", "google", "googleid1") - g_user_by_tpid2b = await get_user_by_third_party_info("t2", "facebook", "fbid1") - g_user_by_tpid3a = await get_user_by_third_party_info("t3", "google", "googleid1") - g_user_by_tpid3b = await get_user_by_third_party_info("t3", "facebook", "fbid1") + g_user_by_tpid1a = await list_users_by_account_info( + "t1", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="google", third_party_user_id="googleid1" + ) + ), + ) + g_user_by_tpid1b = await list_users_by_account_info( + "t1", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="facebook", third_party_user_id="fbid1" + ) + ), + ) + g_user_by_tpid2a = await list_users_by_account_info( + "t2", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="google", third_party_user_id="googleid1" + ) + ), + ) + g_user_by_tpid2b = await list_users_by_account_info( + "t2", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="facebook", third_party_user_id="fbid1" + ) + ), + ) + g_user_by_tpid3a = await list_users_by_account_info( + "t3", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="google", third_party_user_id="googleid1" + ) + ), + ) + g_user_by_tpid3b = await list_users_by_account_info( + "t3", + AccountInfo( + third_party=ThirdPartyInfo( + third_party_id="facebook", third_party_user_id="fbid1" + ) + ), + ) - assert g_user_by_tpid1a == user1a.user - assert g_user_by_tpid1b == user1b.user - assert g_user_by_tpid2a == user2a.user - assert g_user_by_tpid2b == user2b.user - assert g_user_by_tpid3a == user3a.user - assert g_user_by_tpid3b == user3b.user + assert g_user_by_tpid1a == [user1a.user] + assert g_user_by_tpid1b == [user1b.user] + assert g_user_by_tpid2a == [user2a.user] + assert g_user_by_tpid2b == [user2b.user] + assert g_user_by_tpid3a == [user3a.user] + assert g_user_by_tpid3b == [user3b.user] async def test_get_provider(): @@ -157,9 +223,15 @@ async def test_get_provider(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(third_party_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(third_party_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["thirdparty"]) + ) await create_or_update_third_party_config( "t1", diff --git a/tests/thirdparty/test_thirdparty.py b/tests/thirdparty/test_thirdparty.py index 5cfd928cd..ce567f059 100644 --- a/tests/thirdparty/test_thirdparty.py +++ b/tests/thirdparty/test_thirdparty.py @@ -396,4 +396,7 @@ async def test_signinup_generating_fake_email( assert res.status_code == 200 res_json = res.json() assert res_json["status"] == "OK" - assert res_json["user"]["email"] == "customid.custom@stfakeemail.supertokens.com" + assert ( + res_json["user"]["emails"][0] == "customid.custom@stfakeemail.supertokens.com" + ) + assert len(res_json["user"]["emails"]) == 1 diff --git a/tests/useridmapping/create_user_id_mapping.py b/tests/useridmapping/create_user_id_mapping.py index a872d413a..1cff74892 100644 --- a/tests/useridmapping/create_user_id_mapping.py +++ b/tests/useridmapping/create_user_id_mapping.py @@ -55,7 +55,7 @@ async def test_create_user_id_mapping(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_user_info = "externalIdInfo" @@ -87,7 +87,7 @@ async def test_create_user_id_mapping_without_and_with_force(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" # Add metadata to the user: @@ -146,7 +146,7 @@ async def create_user_id_mapping_when_mapping_already_exists(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" # Create User ID Mapping: @@ -169,7 +169,7 @@ async def create_user_id_mapping_when_mapping_already_exists(): # Try creating a duplicate mapping where external_user_id exists and but supertokens_user_id doesn't (new) sign_up_res = await sign_up("public", "foo@bar.com", "baz") assert isinstance(sign_up_res, SignUpOkResult) - new_supertokens_user_id = sign_up_res.user.user_id + new_supertokens_user_id = sign_up_res.user.id res = await create_user_id_mapping(new_supertokens_user_id, external_user_id) assert isinstance(res, UserIdMappingAlreadyExistsError) diff --git a/tests/useridmapping/delete_user_id_mapping.py b/tests/useridmapping/delete_user_id_mapping.py index 0915d6d33..ebc3af883 100644 --- a/tests/useridmapping/delete_user_id_mapping.py +++ b/tests/useridmapping/delete_user_id_mapping.py @@ -66,7 +66,7 @@ async def test_delete_user_id_mapping(user_type: USER_TYPE): sign_up_res = await sign_up("public", "test@example.com", "password") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_id_info = "externalIdInfo" @@ -104,7 +104,7 @@ async def test_delete_user_id_mapping_without_and_with_force(): sign_up_res = await sign_up("public", "test@example.com", "testPass123") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_user_info = "externalIdInfo" diff --git a/tests/useridmapping/get_user_id_mapping.py b/tests/useridmapping/get_user_id_mapping.py index 517d076f4..030f1824f 100644 --- a/tests/useridmapping/get_user_id_mapping.py +++ b/tests/useridmapping/get_user_id_mapping.py @@ -62,7 +62,7 @@ async def test_get_user_id_mapping(use_external_id_info: bool): sign_up_res = await sign_up("public", "test@example.com", "password") assert isinstance(sign_up_res, SignUpOkResult) - supertokens_user_id = sign_up_res.user.user_id + supertokens_user_id = sign_up_res.user.id external_user_id = "externalId" external_id_info = "externalIdInfo" if use_external_id_info else None diff --git a/tests/useridmapping/recipe_tests.py b/tests/useridmapping/recipe_tests.py index 42e2d1c94..20b1309e0 100644 --- a/tests/useridmapping/recipe_tests.py +++ b/tests/useridmapping/recipe_tests.py @@ -21,10 +21,11 @@ from supertokens_python.querier import Querier from supertokens_python.recipe.emailpassword.interfaces import ( SignUpOkResult, - ResetPasswordUsingTokenOkResult, SignInOkResult, CreateResetPasswordOkResult, + UpdateEmailOrPasswordOkResult, ) +from supertokens_python.types import AccountInfo, RecipeUserId from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from .utils import st_config @@ -55,23 +56,23 @@ async def ep_get_new_user_id(email: str) -> str: sign_up_res = await sign_up("public", email, "password") assert isinstance(sign_up_res, SignUpOkResult) - return sign_up_res.user.user_id + return sign_up_res.user.id async def ep_get_existing_user_id(user_id: str) -> str: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_id + from supertokens_python.asyncio import get_user - res = await get_user_by_id(user_id) + res = await get_user(user_id) assert res is not None - return res.user_id + return res.id async def ep_get_existing_user_by_email(email: str) -> str: - from supertokens_python.recipe.emailpassword.asyncio import get_user_by_email + from supertokens_python.asyncio import list_users_by_account_info - res = await get_user_by_email("public", email) - assert res is not None - return res.user_id + res = await list_users_by_account_info("public", AccountInfo(email=email)) + assert len(res) == 1 + return res[0].id async def ep_get_existing_user_by_signin(email: str) -> str: @@ -79,22 +80,21 @@ async def ep_get_existing_user_by_signin(email: str) -> str: res = await sign_in("public", email, "password") assert isinstance(res, SignInOkResult) - return res.user.user_id + return res.user.id async def ep_get_existing_user_after_reset_password(user_id: str) -> str: - new_password = "password" + new_password = "password1234" from supertokens_python.recipe.emailpassword.asyncio import ( create_reset_password_token, reset_password_using_token, ) - result = await create_reset_password_token("public", user_id) + result = await create_reset_password_token("public", user_id, "") assert isinstance(result, CreateResetPasswordOkResult) res = await reset_password_using_token("public", result.token, new_password) - assert isinstance(res, ResetPasswordUsingTokenOkResult) - assert res.user_id is not None - return res.user_id + assert isinstance(res, UpdateEmailOrPasswordOkResult) + return user_id async def ep_get_existing_user_after_updating_email_and_sign_in(user_id: str) -> str: @@ -105,12 +105,14 @@ async def ep_get_existing_user_after_updating_email_and_sign_in(user_id: str) -> sign_in, ) - res = await update_email_or_password(user_id, new_email, "password") - assert isinstance(res, SignUpOkResult) + res = await update_email_or_password( + RecipeUserId(user_id), new_email, "password1234" + ) + assert isinstance(res, UpdateEmailOrPasswordOkResult) - res = await sign_in("public", new_email, "password") + res = await sign_in("public", new_email, "password1234") assert isinstance(res, SignInOkResult) - return res.user.user_id + return res.user.id @mark.parametrize("use_external_id_info", [(True,), (False,)]) @@ -128,7 +130,7 @@ async def test_get_user_id_mapping(use_external_id_info: bool): external_user_id = "externalId" external_id_info = "externalIdInfo" if use_external_id_info else None - assert ep_get_existing_user_id(supertokens_user_id) == supertokens_user_id + assert await ep_get_existing_user_id(supertokens_user_id) == supertokens_user_id # Create user id mapping res = await create_user_id_mapping( @@ -138,16 +140,17 @@ async def test_get_user_id_mapping(use_external_id_info: bool): # Now we should get the external user ID instead of ST user ID # irrespective of whether we pass ST User ID or External User ID - assert ep_get_existing_user_id(supertokens_user_id) == external_user_id - assert ep_get_existing_user_id(external_user_id) == external_user_id + assert await ep_get_existing_user_id(supertokens_user_id) == external_user_id + assert await ep_get_existing_user_id(external_user_id) == external_user_id # Same happens for all the functions - assert ep_get_existing_user_by_email(email) == external_user_id - assert ep_get_existing_user_by_signin(email) == external_user_id + assert await ep_get_existing_user_by_email(email) == external_user_id + assert await ep_get_existing_user_by_signin(email) == external_user_id assert ( - ep_get_existing_user_after_reset_password(external_user_id) == external_user_id + await ep_get_existing_user_after_reset_password(external_user_id) + == external_user_id ) assert ( - ep_get_existing_user_after_updating_email_and_sign_in(external_user_id) + await ep_get_existing_user_after_updating_email_and_sign_in(external_user_id) == external_user_id ) diff --git a/tests/userroles/test_claims.py b/tests/userroles/test_claims.py index ed0160bea..db13577fb 100644 --- a/tests/userroles/test_claims.py +++ b/tests/userroles/test_claims.py @@ -19,6 +19,7 @@ from supertokens_python import init from supertokens_python.recipe import userroles, session from supertokens_python.recipe.session.exceptions import ClaimValidationError +from supertokens_python.types import RecipeUserId from tests.utils import ( start_st, setup_function, @@ -56,7 +57,7 @@ async def test_add_claims_to_session_without_config(): user_id = "userId" req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) assert s.sync_get_claim_value(UserRoleClaim) == [] assert (await s.get_claim_value(PermissionClaim)) == [] @@ -78,7 +79,7 @@ async def test_claims_not_added_to_session_if_disabled(): user_id = "userId" req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) assert (await s.get_claim_value(UserRoleClaim)) is None assert s.sync_get_claim_value(PermissionClaim) is None @@ -101,7 +102,7 @@ async def test_add_claims_to_session_with_values(): await create_new_role_or_add_permissions(role, ["a", "b"]) await add_role_to_user("public", user_id, role) - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) assert s.sync_get_claim_value(UserRoleClaim) == [role] value: List[str] = await s.get_claim_value(PermissionClaim) # type: ignore assert sorted(value) == sorted(["a", "b"]) @@ -126,7 +127,7 @@ async def test_should_validate_roles(): await create_new_role_or_add_permissions(role, ["a", "b"]) await add_role_to_user("public", user_id, role) - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await s.assert_claims([UserRoleClaim.validators.includes(role)]) with pytest.raises(Exception) as e: @@ -134,7 +135,8 @@ async def test_should_validate_roles(): assert e.typename == "InvalidClaimsError" err: ClaimValidationError (err,) = e.value.payload # type: ignore - assert err.id == UserRoleClaim.key + assert isinstance(err, ClaimValidationError) + assert err.id_ == UserRoleClaim.key assert err.reason == { "message": "wrong value", "expectedToInclude": invalid_role, @@ -159,7 +161,7 @@ async def test_should_validate_roles_after_refetch(): role = "role" req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await create_new_role_or_add_permissions(role, ["a", "b"]) await add_role_to_user("public", user_id, role) @@ -187,7 +189,7 @@ async def test_should_validate_permissions(): await create_new_role_or_add_permissions(role, permissions) await add_role_to_user("public", user_id, role) - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await s.assert_claims([PermissionClaim.validators.includes("a")]) with pytest.raises(Exception) as e: @@ -195,8 +197,10 @@ async def test_should_validate_permissions(): assert e.typename == "InvalidClaimsError" err: ClaimValidationError (err,) = e.value.payload # type: ignore - assert err.id == PermissionClaim.key + assert isinstance(err, ClaimValidationError) + assert err.id_ == PermissionClaim.key assert err.reason is not None + assert isinstance(err.reason, dict) actual_value = err.reason.pop("actualValue") assert sorted(actual_value) == sorted(permissions) assert err.reason == { @@ -223,7 +227,7 @@ async def test_should_validate_permissions_after_refetch(): permissions = ["a", "b"] req = MagicMock() - s = await create_new_session(req, "public", user_id) + s = await create_new_session(req, "public", RecipeUserId(user_id)) await create_new_role_or_add_permissions(role, permissions) await add_role_to_user("public", user_id, role) diff --git a/tests/userroles/test_multitenancy.py b/tests/userroles/test_multitenancy.py index 54017f1ff..51cb5d133 100644 --- a/tests/userroles/test_multitenancy.py +++ b/tests/userroles/test_multitenancy.py @@ -20,12 +20,15 @@ ) from supertokens_python.recipe.emailpassword.asyncio import sign_up from supertokens_python.recipe.emailpassword.interfaces import SignUpOkResult -from supertokens_python.recipe.multitenancy.interfaces import TenantConfig +from supertokens_python.recipe.multitenancy.interfaces import ( + TenantConfigCreateOrUpdate, +) from supertokens_python.recipe.userroles.asyncio import ( create_new_role_or_add_permissions, add_role_to_user, get_roles_for_user, ) +from supertokens_python.types import RecipeUserId from tests.utils import get_st_init_args from tests.utils import ( @@ -56,17 +59,23 @@ async def test_multitenancy_in_user_roles(): start_st() setup_multitenancy_feature() - await create_or_update_tenant("t1", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t2", TenantConfig(email_password_enabled=True)) - await create_or_update_tenant("t3", TenantConfig(email_password_enabled=True)) + await create_or_update_tenant( + "t1", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t2", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) + await create_or_update_tenant( + "t3", TenantConfigCreateOrUpdate(first_factors=["emailpassword"]) + ) user = await sign_up("public", "test@example.com", "password1") assert isinstance(user, SignUpOkResult) - user_id = user.user.user_id + user_id = user.user.id - await associate_user_to_tenant("t1", user_id) - await associate_user_to_tenant("t2", user_id) - await associate_user_to_tenant("t3", user_id) + await associate_user_to_tenant("t1", RecipeUserId(user_id)) + await associate_user_to_tenant("t2", RecipeUserId(user_id)) + await associate_user_to_tenant("t3", RecipeUserId(user_id)) await create_new_role_or_add_permissions("role1", []) await create_new_role_or_add_permissions("role2", []) diff --git a/tests/utils.py b/tests/utils.py index 76115c27c..f7493d053 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -41,6 +41,9 @@ from supertokens_python.recipe.usermetadata import UserMetadataRecipe from supertokens_python.recipe.userroles import UserRolesRecipe from supertokens_python.utils import is_version_gte +from supertokens_python.recipe.accountlinking.recipe import AccountLinkingRecipe +from supertokens_python.recipe.multifactorauth.recipe import MultiFactorAuthRecipe +from supertokens_python.recipe.totp.recipe import TOTPRecipe INSTALLATION_PATH = environ["SUPERTOKENS_PATH"] SUPERTOKENS_PROCESS_DIR = INSTALLATION_PATH + "/.started" @@ -220,6 +223,9 @@ def reset(stop_core: bool = True): DashboardRecipe.reset() PasswordlessRecipe.reset() MultitenancyRecipe.reset() + AccountLinkingRecipe.reset() + MultiFactorAuthRecipe.reset() + TOTPRecipe.reset() def get_cookie_from_response(response: Response, cookie_name: str): @@ -266,12 +272,12 @@ def extract_info(response: Response) -> Dict[str, Any]: "antiCsrf": response.headers.get("anti-csrf"), "accessTokenFromHeader": access_token_from_header, "refreshTokenFromHeader": refresh_token_from_header, - "accessTokenFromAny": access_token_from_header - if access_token is None - else access_token, - "refreshTokenFromAny": refresh_token_from_header - if refresh_token is None - else refresh_token, + "accessTokenFromAny": ( + access_token_from_header if access_token is None else access_token + ), + "refreshTokenFromAny": ( + refresh_token_from_header if refresh_token is None else refresh_token + ), } @@ -600,5 +606,5 @@ async def create_users( ) elif user["recipe"] == "thirdparty" and thirdparty: await manually_create_or_update_user( - "public", user["provider"], user["userId"], user["email"] + "public", user["provider"], user["userId"], user["email"], True, None )