diff --git a/tests/test-server/app.py b/tests/test-server/app.py index 2cf199fb..bd7489c8 100644 --- a/tests/test-server/app.py +++ b/tests/test-server/app.py @@ -141,7 +141,6 @@ def builder(oI: T) -> T: for attr in dir(oI) if callable(getattr(oI, attr)) and not attr.startswith("__") ] - for member in members: create_override(oI, member, name, override_name) return oI @@ -340,7 +339,7 @@ def init_st(config: Dict[str, Any]): ), ) - include_in_non_public_tenants_by_default = None + include_in_non_public_tenants_by_default = False if "includeInNonPublicTenantsByDefault" in provider: include_in_non_public_tenants_by_default = provider[ @@ -391,6 +390,10 @@ def init_st(config: Dict[str, Any]): ), ), include_in_non_public_tenants_by_default=include_in_non_public_tenants_by_default, + override=override_builder_with_logging( + "ThirdParty.providers.override", + provider.get("override", None), + ), ) providers.append(provider_input) recipe_list.append( @@ -662,6 +665,9 @@ def not_found(error: Any) -> Any: # pylint: disable=unused-argument add_accountlinking_routes(app) add_passwordless_routes(app) add_totp_routes(app) +from supertokens import add_supertokens_routes # pylint: disable=import-error + +add_supertokens_routes(app) if __name__ == "__main__": default_st_init() diff --git a/tests/test-server/session.py b/tests/test-server/session.py index 00c5dcda..20bc0080 100644 --- a/tests/test-server/session.py +++ b/tests/test-server/session.py @@ -1,6 +1,7 @@ -from typing import Any +from typing import Any, Dict from flask import Flask, request, jsonify from override_logging import log_override_event # pylint: disable=import-error +from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.interfaces import TokenInfo from supertokens_python.recipe.session.jwt import ( parse_jwt_without_signature_verification, @@ -50,27 +51,7 @@ def create_new_session_without_request_response(): # type: ignore user_context, ) - return jsonify( - { - "sessionHandle": session_container.get_handle(), - "userId": session_container.get_user_id(), - "tenantId": session_container.get_tenant_id(), - "userDataInAccessToken": session_container.get_access_token_payload(), - "accessToken": session_container.get_access_token(), - "frontToken": session_container.get_all_session_tokens_dangerously()[ - "frontToken" - ], - "refreshToken": session_container.get_all_session_tokens_dangerously()[ - "refreshToken" - ], - "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ - "antiCsrfToken" - ], - "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ - "accessAndFrontTokenUpdated" - ], - } - ) + return jsonify(convert_session_to_json(session_container)) @app.route("/test/session/getsessionwithoutrequestresponse", methods=["POST"]) # type: ignore def get_session_without_request_response(): # type: ignore @@ -83,13 +64,14 @@ def get_session_without_request_response(): # type: ignore options = data.get("options") user_context = data.get("userContext", {}) - try: - session_container = session.get_session_without_request_response( - access_token, anti_csrf_token, options, user_context - ) - return jsonify(session_container) - except Exception as e: - return jsonify({"error": str(e)}), 500 + session_container = session.get_session_without_request_response( + access_token, anti_csrf_token, options, user_context + ) + return jsonify( + None + if session_container is None + else convert_session_to_json(session_container) + ) @app.route("/test/session/sessionobject/assertclaims", methods=["POST"]) # type: ignore def assert_claims(): # type: ignore @@ -213,6 +195,28 @@ def fetch_and_set_claim_api(): # type: ignore return jsonify(response) +def convert_session_to_json(session_container: SessionContainer) -> Dict[str, Any]: + return { + "sessionHandle": session_container.get_handle(), + "userId": session_container.get_user_id(), + "tenantId": session_container.get_tenant_id(), + "userDataInAccessToken": session_container.get_access_token_payload(), + "accessToken": session_container.get_access_token(), + "frontToken": session_container.get_all_session_tokens_dangerously()[ + "frontToken" + ], + "refreshToken": session_container.get_all_session_tokens_dangerously()[ + "refreshToken" + ], + "antiCsrfToken": session_container.get_all_session_tokens_dangerously()[ + "antiCsrfToken" + ], + "accessTokenUpdated": session_container.get_all_session_tokens_dangerously()[ + "accessAndFrontTokenUpdated" + ], + } + + def convert_session_to_container(data: Any) -> Session: jwt_info = parse_jwt_without_signature_verification(data["session"]["accessToken"]) jwt_payload = jwt_info.payload diff --git a/tests/test-server/supertokens.py b/tests/test-server/supertokens.py new file mode 100644 index 00000000..3cc40d50 --- /dev/null +++ b/tests/test-server/supertokens.py @@ -0,0 +1,89 @@ +from flask import Flask, request, jsonify +from supertokens_python.recipe.thirdparty.types import ThirdPartyInfo +from supertokens_python.types import AccountInfo +from supertokens_python.syncio import ( + get_user, + delete_user, + list_users_by_account_info, + get_users_newest_first, + get_users_oldest_first, +) + + +def add_supertokens_routes(app: Flask): + @app.route("/test/supertokens/getuser", methods=["POST"]) # type: ignore + def get_user_api(): # type: ignore + assert request.json is not None + response = get_user(request.json["userId"], request.json.get("userContext")) + return jsonify(None if response is None else response.to_json()) + + @app.route("/test/supertokens/deleteuser", methods=["POST"]) # type: ignore + def delete_user_api(): # type: ignore + assert request.json is not None + delete_user( + request.json["userId"], + request.json["removeAllLinkedAccounts"], + request.json.get("userContext"), + ) + return jsonify({"status": "OK"}) + + @app.route("/test/supertokens/listusersbyaccountinfo", methods=["POST"]) # type: ignore + def list_users_by_account_info_api(): # type: ignore + assert request.json is not None + response = list_users_by_account_info( + request.json["tenantId"], + AccountInfo( + email=request.json["accountInfo"].get("email", None), + phone_number=request.json["accountInfo"].get("phoneNumber", None), + third_party=( + None + if "thirdParty" not in request.json["accountInfo"] + else ThirdPartyInfo( + third_party_id=request.json["accountInfo"]["thirdParty"][ + "thirdPartyId" + ], + third_party_user_id=request.json["accountInfo"]["thirdParty"][ + "id" + ], + ) + ), + ), + request.json["doUnionOfAccountInfo"], + request.json.get("userContext"), + ) + + return jsonify([r.to_json() for r in response]) + + @app.route("/test/supertokens/getusersnewestfirst", methods=["POST"]) # type: ignore + def get_users_newest_first_api(): # type: ignore + assert request.json is not None + response = get_users_newest_first( + include_recipe_ids=request.json["includeRecipeIds"], + limit=request.json["limit"], + pagination_token=request.json["paginationToken"], + tenant_id=request.json["tenantId"], + user_context=request.json.get("userContext"), + ) + return jsonify( + { + "nextPaginationToken": response.next_pagination_token, + "users": [r.to_json() for r in response.users], + } + ) + + @app.route("/test/supertokens/getusersoldestfirst", methods=["POST"]) # type: ignore + def get_users_oldest_first_api(): # type: ignore + assert request.json is not None + response = get_users_oldest_first( + include_recipe_ids=request.json["includeRecipeIds"], + limit=request.json["limit"], + pagination_token=request.json["paginationToken"], + tenant_id=request.json["tenantId"], + user_context=request.json.get("userContext"), + ) + return jsonify( + { + "nextPaginationToken": response.next_pagination_token, + "users": [r.to_json() for r in response.users], + } + ) diff --git a/tests/test-server/test_functions_mapper.py b/tests/test-server/test_functions_mapper.py index 188e9728..63f32d0f 100644 --- a/tests/test-server/test_functions_mapper.py +++ b/tests/test-server/test_functions_mapper.py @@ -6,6 +6,11 @@ ShouldAutomaticallyLink, ShouldNotAutomaticallyLink, ) +from supertokens_python.recipe.thirdparty.types import ( + RawUserInfoFromProvider, + UserInfo, + UserInfoEmail, +) from supertokens_python.types import AccountInfo, RecipeUserId from supertokens_python.types import APIResponse, User @@ -124,7 +129,164 @@ async def func( return func - raise Exception("Unknown eval string") + if eval_str.startswith("thirdparty.init.signInAndUpFeature.providers"): + + def custom_provider(provider: Any): + if "custom-ev" in eval_str: + + def exchange_auth_code_for_oauth_tokens1( + redirect_uri_info: Any, # pylint: disable=unused-argument + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return {} + + def get_user_info1( + oauth_tokens: Any, + user_context: Any, # pylint: disable=unused-argument + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "user"), + email=UserInfoEmail( + email=oauth_tokens.get("email", "email@test.com"), + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens1 + ) + provider.get_user_info = get_user_info1 + return provider + + if "custom-no-ev" in eval_str: + + def exchange_auth_code_for_oauth_tokens2( + redirect_uri_info: Any, # pylint: disable=unused-argument + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return {} + + def get_user_info2( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "user"), + email=UserInfoEmail( + email=oauth_tokens.get("email", "email@test.com"), + is_verified=False, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens2 + ) + provider.get_user_info = get_user_info2 + return provider + + if "custom2" in eval_str: + + def exchange_auth_code_for_oauth_tokens3( + redirect_uri_info: Any, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info["redirectURIQueryParams"] + + def get_user_info3( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=f"custom2{oauth_tokens['email']}", + email=UserInfoEmail( + email=oauth_tokens["email"], + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens3 + ) + provider.get_user_info = get_user_info3 + return provider + + if "custom3" in eval_str: + + def exchange_auth_code_for_oauth_tokens4( + redirect_uri_info: Any, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info["redirectURIQueryParams"] + + def get_user_info4( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + return UserInfo( + third_party_user_id=oauth_tokens["email"], + email=UserInfoEmail( + email=oauth_tokens["email"], + is_verified=True, + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens4 + ) + provider.get_user_info = get_user_info4 + return provider + + if "custom" in eval_str: + + def exchange_auth_code_for_oauth_tokens5( + redirect_uri_info: Any, + user_context: Any, # pylint: disable=unused-argument + ) -> Any: + return redirect_uri_info + + def get_user_info5( + oauth_tokens: Any, user_context: Any + ): # pylint: disable=unused-argument + if oauth_tokens.get("error"): + raise Exception("Credentials error") + return UserInfo( + third_party_user_id=oauth_tokens.get("userId", "userId"), + email=( + None + if oauth_tokens.get("email") is None + else UserInfoEmail( + email=oauth_tokens.get("email"), + is_verified=oauth_tokens.get("isVerified", False), + ) + ), + raw_user_info_from_provider=RawUserInfoFromProvider( + from_id_token_payload=None, + from_user_info_api=None, + ), + ) + + provider.exchange_auth_code_for_oauth_tokens = ( + exchange_auth_code_for_oauth_tokens5 + ) + provider.get_user_info = get_user_info5 + return provider + + return custom_provider + + raise Exception("Unknown eval string: " + eval_str) class OverrideParams(APIResponse):