diff --git a/supertokens_python/recipe/emailverification/recipe_implementation.py b/supertokens_python/recipe/emailverification/recipe_implementation.py index fce767d1..ea5ab3fb 100644 --- a/supertokens_python/recipe/emailverification/recipe_implementation.py +++ b/supertokens_python/recipe/emailverification/recipe_implementation.py @@ -31,10 +31,10 @@ ) from .types import EmailVerificationUser from supertokens_python.asyncio import get_user +from supertokens_python.types import RecipeUserId, User if TYPE_CHECKING: from supertokens_python.querier import Querier - from supertokens_python.types import RecipeUserId, User class RecipeImplementation(RecipeInterface): diff --git a/supertokens_python/types.py b/supertokens_python/types.py index 2eb5535c..e4b5ee63 100644 --- a/supertokens_python/types.py +++ b/supertokens_python/types.py @@ -70,7 +70,23 @@ def __init__( self.time_joined = time_joined self.verified = verified + def __eq__(self, other: Any) -> bool: + if isinstance(other, LoginMethod): + return ( + self.recipe_id == other.recipe_id + and self.recipe_user_id == other.recipe_user_id + and self.tenant_ids == other.tenant_ids + and self.has_same_email_as(other.email) + and self.has_same_phone_number_as(other.phone_number) + and self.has_same_third_party_info_as(other.third_party) + and self.time_joined == other.time_joined + and self.verified == other.verified + ) + return False + def has_same_email_as(self, email: Union[str, None]) -> bool: + if self.email is None and email is None: + return True if email is None: return False return ( @@ -79,6 +95,8 @@ def has_same_email_as(self, email: Union[str, None]) -> bool: ) def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: + if self.phone_number is None and phone_number is None: + return True if phone_number is None: return False @@ -95,6 +113,8 @@ def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool: def has_same_third_party_info_as( self, third_party: Union[ThirdPartyInfo, None] ) -> bool: + if third_party is None and self.third_party is None: + return True if third_party is None or self.third_party is None: return False return ( @@ -157,6 +177,20 @@ def __init__( self.login_methods = login_methods self.time_joined = time_joined + def __eq__(self, other: Any) -> bool: + if isinstance(other, User): + return ( + self.id == other.id + and self.is_primary_user == other.is_primary_user + and self.tenant_ids == other.tenant_ids + and self.emails == other.emails + and self.phone_numbers == other.phone_numbers + and self.third_party == other.third_party + and self.login_methods == other.login_methods + and self.time_joined == other.time_joined + ) + return False + def to_json(self) -> Dict[str, Any]: return { "id": self.id, diff --git a/tests/emailpassword/test_emaildelivery.py b/tests/emailpassword/test_emaildelivery.py index d24e88e5..13c21695 100644 --- a/tests/emailpassword/test_emaildelivery.py +++ b/tests/emailpassword/test_emaildelivery.py @@ -20,6 +20,7 @@ import respx from fastapi import FastAPI from fastapi.requests import Request +from supertokens_python.types import RecipeUserId from tests.testclient import TestClientWithNoCookieJar as TestClient from supertokens_python import InputAppInfo, SupertokensConfig, init @@ -206,7 +207,9 @@ class CustomEmailService( async def send_email( self, template_vars: emailpassword.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ) -> None: nonlocal email, password_reset_url email = template_vars.user.email @@ -250,13 +253,14 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]): oi_send_email = oi.send_email async def send_email( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal email, password_reset_url email = template_vars.user.email assert isinstance(template_vars, PasswordResetEmailTemplateVars) password_reset_url = template_vars.password_reset_link - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email return oi @@ -319,11 +323,12 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]): oi_send_email = oi.send_email async def send_email( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): template_vars.user.email = "override@example.com" assert isinstance(template_vars, PasswordResetEmailTemplateVars) - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email return oi @@ -332,7 +337,9 @@ class CustomEmailService( emailpassword.EmailDeliveryInterface[emailpassword.EmailTemplateVars] ): async def send_email( - self, template_vars: Any, user_context: Dict[str, Any] + self, + template_vars: Any, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> None: nonlocal email, password_reset_url email = template_vars.user.email @@ -384,7 +391,10 @@ async def test_reset_password_smtp_service(driver_config_client: TestClient): def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -396,7 +406,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, password_reset_url get_content_called = True @@ -433,11 +444,12 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email_override return oi @@ -486,7 +498,8 @@ async def test_reset_password_for_non_existent_user(driver_config_client: TestCl def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -498,7 +511,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, password_reset_url get_content_called = True @@ -535,11 +549,12 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: EmailTemplateVars, _user_context: Dict[str, Any] + template_vars: EmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True - await oi_send_email(template_vars, _user_context) + await oi_send_email(template_vars, user_context) oi.send_email = send_email_override return oi @@ -613,7 +628,7 @@ async def test_email_verification_default_backward_compatibility( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -679,7 +694,7 @@ async def test_email_verification_default_backward_compatibility_suppress_error( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -727,7 +742,9 @@ class CustomEmailService( async def send_email( self, template_vars: emailverification.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ) -> None: nonlocal email, email_verify_url email = template_vars.user.email @@ -762,7 +779,7 @@ async def send_email( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) res = email_verify_token_request( @@ -792,7 +809,8 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email( - template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal email, email_verify_url email = template_vars.user.email @@ -833,7 +851,7 @@ async def send_email( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) def api_side_effect(request: httpx.Request): @@ -878,7 +896,9 @@ class CustomEmailDeliveryService( async def send_email( self, template_vars: emailpassword.EmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal email, password_reset_url email = template_vars.user.email @@ -925,7 +945,8 @@ async def test_email_verification_smtp_service(driver_config_client: TestClient) def smtp_service_override(oi: SMTPServiceInterface[VerificationEmailTemplateVars]): async def send_raw_email_override( - content: EmailContent, _user_context: Dict[str, Any] + content: EmailContent, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal send_raw_email_called, email send_raw_email_called = True @@ -937,7 +958,8 @@ async def send_raw_email_override( # Note that we aren't calling oi.send_raw_email. So Transporter won't be used. async def get_content_override( - template_vars: VerificationEmailTemplateVars, _user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ) -> EmailContent: nonlocal get_content_called, email_verify_url get_content_called = True @@ -974,7 +996,8 @@ def email_delivery_override( oi_send_email = oi.send_email async def send_email_override( - template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any] + template_vars: VerificationEmailTemplateVars, + user_context: Dict[str, Any], # pylint: disable=unused-argument ): nonlocal outer_override_called outer_override_called = True @@ -1013,7 +1036,7 @@ async def send_email_override( if not isinstance(s.recipe_implementation, SessionRecipeImplementation): raise Exception("Should never come here") response = await create_new_session( - s.recipe_implementation, "public", user_id, True, {}, {}, None + s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None ) resp = email_verify_token_request( @@ -1045,7 +1068,9 @@ class CustomEmailDeliveryService( async def send_email( self, template_vars: PasswordResetEmailTemplateVars, - user_context: Dict[str, Any], + user_context: Dict[ + str, Any + ], # pylint: disable=unused-argument, # pylint: disable=unused-argument ): nonlocal reset_url, token_info, tenant_info, query_length password_reset_url = template_vars.password_reset_link @@ -1081,7 +1106,9 @@ async def send_email( dict_response = json.loads(response_1.text) user_info = dict_response["user"] assert dict_response["status"] == "OK" - resp = await send_reset_password_email("public", user_info["id"], "") + resp = await send_reset_password_email( + "public", user_info["id"], "random@gmail.com" + ) assert resp == "OK" assert reset_url == "http://supertokens.io/auth/reset-password" @@ -1118,9 +1145,13 @@ async def test_send_reset_password_email_invalid_input( dict_response = json.loads(response_1.text) user_info = dict_response["user"] - link = await send_reset_password_email("public", "invalidUserId", "") + link = await send_reset_password_email( + "public", "invalidUserId", "random@gmail.com" + ) assert link == "UNKNOWN_USER_ID_ERROR" with raises(Exception) as err: - await send_reset_password_email("invalidTenantId", user_info["id"], "") + await send_reset_password_email( + "invalidTenantId", user_info["id"], "random@gmail.com" + ) assert "status code: 400" in str(err.value) diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 0bf066d6..a87ed908 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -198,7 +198,9 @@ async def test_the_generate_token_api_with_valid_input_email_verified_and_test_e user_id = dict_response["user"]["id"] cookies = extract_all_cookies(response_1) - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) @@ -742,7 +744,7 @@ async def email_verify_post( await asyncio.sleep(1) if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -972,7 +974,7 @@ async def email_verify_post( if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -1083,7 +1085,7 @@ async def email_verify_post( if user_info_from_callback is None: raise Exception("Should never come here") - assert user_info_from_callback.user_id == user_id # type: ignore + assert user_info_from_callback.recipe_user_id.get_as_string() == user_id # type: ignore assert user_info_from_callback.email == "test@gmail.com" # type: ignore @@ -1121,7 +1123,9 @@ async def test_the_generate_token_api_with_valid_input_and_then_remove_token( assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) await revoke_email_verification_tokens("public", user_id) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): @@ -1165,7 +1169,9 @@ async def test_the_generate_token_api_with_valid_input_verify_and_then_unverify_ assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): await verify_email_using_token("public", verify_token.token) @@ -1269,7 +1275,9 @@ async def send_email( cookies = extract_all_cookies(res) # Start verification: - verify_token = await create_email_verification_token("public", user_id) + verify_token = await create_email_verification_token( + "public", RecipeUserId(user_id) + ) assert isinstance(verify_token, CreateEmailVerificationTokenOkResult) await verify_email_using_token("public", verify_token.token) @@ -1370,7 +1378,7 @@ def get_origin(_: Optional[BaseRequest], user_context: Dict[str, Any]) -> str: dict_response = json.loads(response_1.text) assert dict_response["status"] == "OK" user_id = dict_response["user"]["id"] - email = dict_response["user"]["email"] + email = dict_response["user"]["emails"][0] await send_email_verification_email( "public", user_id, RecipeUserId(user_id), email, {"url": "localhost:3000"}