Skip to content

Commit

Permalink
fixes more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Sep 25, 2024
1 parent 76cdb54 commit 160af9b
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
)
from .types import EmailVerificationUser
from supertokens_python.asyncio import get_user
from supertokens_python.types import RecipeUserId, User

if TYPE_CHECKING:
from supertokens_python.querier import Querier
from supertokens_python.types import RecipeUserId, User


class RecipeImplementation(RecipeInterface):
Expand Down
34 changes: 34 additions & 0 deletions supertokens_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,23 @@ def __init__(
self.time_joined = time_joined
self.verified = verified

def __eq__(self, other: Any) -> bool:
if isinstance(other, LoginMethod):
return (
self.recipe_id == other.recipe_id
and self.recipe_user_id == other.recipe_user_id
and self.tenant_ids == other.tenant_ids
and self.has_same_email_as(other.email)
and self.has_same_phone_number_as(other.phone_number)
and self.has_same_third_party_info_as(other.third_party)
and self.time_joined == other.time_joined
and self.verified == other.verified
)
return False

def has_same_email_as(self, email: Union[str, None]) -> bool:
if self.email is None and email is None:
return True
if email is None:
return False
return (
Expand All @@ -79,6 +95,8 @@ def has_same_email_as(self, email: Union[str, None]) -> bool:
)

def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool:
if self.phone_number is None and phone_number is None:
return True
if phone_number is None:
return False

Expand All @@ -95,6 +113,8 @@ def has_same_phone_number_as(self, phone_number: Union[str, None]) -> bool:
def has_same_third_party_info_as(
self, third_party: Union[ThirdPartyInfo, None]
) -> bool:
if third_party is None and self.third_party is None:
return True
if third_party is None or self.third_party is None:
return False
return (
Expand Down Expand Up @@ -157,6 +177,20 @@ def __init__(
self.login_methods = login_methods
self.time_joined = time_joined

def __eq__(self, other: Any) -> bool:
if isinstance(other, User):
return (
self.id == other.id
and self.is_primary_user == other.is_primary_user
and self.tenant_ids == other.tenant_ids
and self.emails == other.emails
and self.phone_numbers == other.phone_numbers
and self.third_party == other.third_party
and self.login_methods == other.login_methods
and self.time_joined == other.time_joined
)
return False

def to_json(self) -> Dict[str, Any]:
return {
"id": self.id,
Expand Down
89 changes: 60 additions & 29 deletions tests/emailpassword/test_emaildelivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import respx
from fastapi import FastAPI
from fastapi.requests import Request
from supertokens_python.types import RecipeUserId
from tests.testclient import TestClientWithNoCookieJar as TestClient

from supertokens_python import InputAppInfo, SupertokensConfig, init
Expand Down Expand Up @@ -206,7 +207,9 @@ class CustomEmailService(
async def send_email(
self,
template_vars: emailpassword.EmailTemplateVars,
user_context: Dict[str, Any],
user_context: Dict[
str, Any
], # pylint: disable=unused-argument, # pylint: disable=unused-argument
) -> None:
nonlocal email, password_reset_url
email = template_vars.user.email
Expand Down Expand Up @@ -250,13 +253,14 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]):
oi_send_email = oi.send_email

async def send_email(
template_vars: EmailTemplateVars, _user_context: Dict[str, Any]
template_vars: EmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal email, password_reset_url
email = template_vars.user.email
assert isinstance(template_vars, PasswordResetEmailTemplateVars)
password_reset_url = template_vars.password_reset_link
await oi_send_email(template_vars, _user_context)
await oi_send_email(template_vars, user_context)

oi.send_email = send_email
return oi
Expand Down Expand Up @@ -319,11 +323,12 @@ def email_delivery_override(oi: EmailDeliveryInterface[EmailTemplateVars]):
oi_send_email = oi.send_email

async def send_email(
template_vars: EmailTemplateVars, _user_context: Dict[str, Any]
template_vars: EmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
template_vars.user.email = "[email protected]"
assert isinstance(template_vars, PasswordResetEmailTemplateVars)
await oi_send_email(template_vars, _user_context)
await oi_send_email(template_vars, user_context)

oi.send_email = send_email
return oi
Expand All @@ -332,7 +337,9 @@ class CustomEmailService(
emailpassword.EmailDeliveryInterface[emailpassword.EmailTemplateVars]
):
async def send_email(
self, template_vars: Any, user_context: Dict[str, Any]
self,
template_vars: Any,
user_context: Dict[str, Any], # pylint: disable=unused-argument
) -> None:
nonlocal email, password_reset_url
email = template_vars.user.email
Expand Down Expand Up @@ -384,7 +391,10 @@ async def test_reset_password_smtp_service(driver_config_client: TestClient):

def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]):
async def send_raw_email_override(
content: EmailContent, _user_context: Dict[str, Any]
content: EmailContent,
user_context: Dict[
str, Any
], # pylint: disable=unused-argument, # pylint: disable=unused-argument
):
nonlocal send_raw_email_called, email
send_raw_email_called = True
Expand All @@ -396,7 +406,8 @@ async def send_raw_email_override(
# Note that we aren't calling oi.send_raw_email. So Transporter won't be used.

async def get_content_override(
template_vars: EmailTemplateVars, _user_context: Dict[str, Any]
template_vars: EmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
) -> EmailContent:
nonlocal get_content_called, password_reset_url
get_content_called = True
Expand Down Expand Up @@ -433,11 +444,12 @@ def email_delivery_override(
oi_send_email = oi.send_email

async def send_email_override(
template_vars: EmailTemplateVars, _user_context: Dict[str, Any]
template_vars: EmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal outer_override_called
outer_override_called = True
await oi_send_email(template_vars, _user_context)
await oi_send_email(template_vars, user_context)

oi.send_email = send_email_override
return oi
Expand Down Expand Up @@ -486,7 +498,8 @@ async def test_reset_password_for_non_existent_user(driver_config_client: TestCl

def smtp_service_override(oi: SMTPServiceInterface[EmailTemplateVars]):
async def send_raw_email_override(
content: EmailContent, _user_context: Dict[str, Any]
content: EmailContent,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal send_raw_email_called, email
send_raw_email_called = True
Expand All @@ -498,7 +511,8 @@ async def send_raw_email_override(
# Note that we aren't calling oi.send_raw_email. So Transporter won't be used.

async def get_content_override(
template_vars: EmailTemplateVars, _user_context: Dict[str, Any]
template_vars: EmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
) -> EmailContent:
nonlocal get_content_called, password_reset_url
get_content_called = True
Expand Down Expand Up @@ -535,11 +549,12 @@ def email_delivery_override(
oi_send_email = oi.send_email

async def send_email_override(
template_vars: EmailTemplateVars, _user_context: Dict[str, Any]
template_vars: EmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal outer_override_called
outer_override_called = True
await oi_send_email(template_vars, _user_context)
await oi_send_email(template_vars, user_context)

oi.send_email = send_email_override
return oi
Expand Down Expand Up @@ -613,7 +628,7 @@ async def test_email_verification_default_backward_compatibility(
if not isinstance(s.recipe_implementation, SessionRecipeImplementation):
raise Exception("Should never come here")
response = await create_new_session(
s.recipe_implementation, "public", user_id, True, {}, {}, None
s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None
)

def api_side_effect(request: httpx.Request):
Expand Down Expand Up @@ -679,7 +694,7 @@ async def test_email_verification_default_backward_compatibility_suppress_error(
if not isinstance(s.recipe_implementation, SessionRecipeImplementation):
raise Exception("Should never come here")
response = await create_new_session(
s.recipe_implementation, "public", user_id, True, {}, {}, None
s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None
)

def api_side_effect(request: httpx.Request):
Expand Down Expand Up @@ -727,7 +742,9 @@ class CustomEmailService(
async def send_email(
self,
template_vars: emailverification.EmailTemplateVars,
user_context: Dict[str, Any],
user_context: Dict[
str, Any
], # pylint: disable=unused-argument, # pylint: disable=unused-argument
) -> None:
nonlocal email, email_verify_url
email = template_vars.user.email
Expand Down Expand Up @@ -762,7 +779,7 @@ async def send_email(
if not isinstance(s.recipe_implementation, SessionRecipeImplementation):
raise Exception("Should never come here")
response = await create_new_session(
s.recipe_implementation, "public", user_id, True, {}, {}, None
s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None
)

res = email_verify_token_request(
Expand Down Expand Up @@ -792,7 +809,8 @@ def email_delivery_override(
oi_send_email = oi.send_email

async def send_email(
template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any]
template_vars: VerificationEmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal email, email_verify_url
email = template_vars.user.email
Expand Down Expand Up @@ -833,7 +851,7 @@ async def send_email(
if not isinstance(s.recipe_implementation, SessionRecipeImplementation):
raise Exception("Should never come here")
response = await create_new_session(
s.recipe_implementation, "public", user_id, True, {}, {}, None
s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None
)

def api_side_effect(request: httpx.Request):
Expand Down Expand Up @@ -878,7 +896,9 @@ class CustomEmailDeliveryService(
async def send_email(
self,
template_vars: emailpassword.EmailTemplateVars,
user_context: Dict[str, Any],
user_context: Dict[
str, Any
], # pylint: disable=unused-argument, # pylint: disable=unused-argument
):
nonlocal email, password_reset_url
email = template_vars.user.email
Expand Down Expand Up @@ -925,7 +945,8 @@ async def test_email_verification_smtp_service(driver_config_client: TestClient)

def smtp_service_override(oi: SMTPServiceInterface[VerificationEmailTemplateVars]):
async def send_raw_email_override(
content: EmailContent, _user_context: Dict[str, Any]
content: EmailContent,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal send_raw_email_called, email
send_raw_email_called = True
Expand All @@ -937,7 +958,8 @@ async def send_raw_email_override(
# Note that we aren't calling oi.send_raw_email. So Transporter won't be used.

async def get_content_override(
template_vars: VerificationEmailTemplateVars, _user_context: Dict[str, Any]
template_vars: VerificationEmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
) -> EmailContent:
nonlocal get_content_called, email_verify_url
get_content_called = True
Expand Down Expand Up @@ -974,7 +996,8 @@ def email_delivery_override(
oi_send_email = oi.send_email

async def send_email_override(
template_vars: VerificationEmailTemplateVars, user_context: Dict[str, Any]
template_vars: VerificationEmailTemplateVars,
user_context: Dict[str, Any], # pylint: disable=unused-argument
):
nonlocal outer_override_called
outer_override_called = True
Expand Down Expand Up @@ -1013,7 +1036,7 @@ async def send_email_override(
if not isinstance(s.recipe_implementation, SessionRecipeImplementation):
raise Exception("Should never come here")
response = await create_new_session(
s.recipe_implementation, "public", user_id, True, {}, {}, None
s.recipe_implementation, "public", RecipeUserId(user_id), True, {}, {}, None
)

resp = email_verify_token_request(
Expand Down Expand Up @@ -1045,7 +1068,9 @@ class CustomEmailDeliveryService(
async def send_email(
self,
template_vars: PasswordResetEmailTemplateVars,
user_context: Dict[str, Any],
user_context: Dict[
str, Any
], # pylint: disable=unused-argument, # pylint: disable=unused-argument
):
nonlocal reset_url, token_info, tenant_info, query_length
password_reset_url = template_vars.password_reset_link
Expand Down Expand Up @@ -1081,7 +1106,9 @@ async def send_email(
dict_response = json.loads(response_1.text)
user_info = dict_response["user"]
assert dict_response["status"] == "OK"
resp = await send_reset_password_email("public", user_info["id"], "")
resp = await send_reset_password_email(
"public", user_info["id"], "[email protected]"
)
assert resp == "OK"

assert reset_url == "http://supertokens.io/auth/reset-password"
Expand Down Expand Up @@ -1118,9 +1145,13 @@ async def test_send_reset_password_email_invalid_input(
dict_response = json.loads(response_1.text)
user_info = dict_response["user"]

link = await send_reset_password_email("public", "invalidUserId", "")
link = await send_reset_password_email(
"public", "invalidUserId", "[email protected]"
)
assert link == "UNKNOWN_USER_ID_ERROR"

with raises(Exception) as err:
await send_reset_password_email("invalidTenantId", user_info["id"], "")
await send_reset_password_email(
"invalidTenantId", user_info["id"], "[email protected]"
)
assert "status code: 400" in str(err.value)
Loading

0 comments on commit 160af9b

Please sign in to comment.