diff --git a/tasks.py b/tasks.py index 5df00e3..2fc8f22 100644 --- a/tasks.py +++ b/tasks.py @@ -2,7 +2,7 @@ APP_PACKAGE = "deep_ice" -PACKAGES = f"{APP_PACKAGE} alembic" +PACKAGES = f"{APP_PACKAGE} alembic tests" # Helper function to run commands with 'uv run' and provide CI-friendly logging. diff --git a/tests/conftest.py b/tests/conftest.py index 719d5be..3e196e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,14 @@ +import asyncio +from typing import cast from unittest.mock import AsyncMock import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import ( + async_scoped_session, + async_sessionmaker, + create_async_engine, +) from sqlmodel import insert from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.pool import StaticPool @@ -29,8 +35,8 @@ def redis_client(mocker): ) -@pytest.fixture(name="session") -async def session_fixture(): +@pytest.fixture +async def _scoped_session(): async_engine = create_async_engine( "sqlite+aiosqlite://", connect_args={"check_same_thread": False}, @@ -38,11 +44,19 @@ async def session_fixture(): ) async with async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - - async_session = async_sessionmaker( + session_factory = async_sessionmaker( bind=async_engine, class_=AsyncSession, expire_on_commit=False ) - async with async_session() as session: + scoped_session = async_scoped_session( + session_factory, scopefunc=asyncio.current_task + ) + yield scoped_session + await scoped_session.remove() + + +@pytest.fixture +async def session(_scoped_session: async_scoped_session): + async with _scoped_session() as session: yield session @@ -89,53 +103,68 @@ async def initial_data(session: AsyncSession) -> dict: }, ] - await session.exec(insert(IceCream).values(icecream_dump)) + await session.exec(insert(IceCream).values(icecream_dump)) # type: ignore clean_users_dump = [ {key: value for key, value in user_data.items() if key != "password"} for user_data in users_dump ] - await session.exec(insert(User).values(clean_users_dump)) + await session.exec(insert(User).values(clean_users_dump)) # type: ignore await session.commit() return {"icecream": icecream_dump, "users": users_dump} -@pytest.fixture(name="client") -async def client_fixture(session: AsyncSession, mocker): - async def get_session_override(): - yield session +@pytest.fixture +async def users(session: AsyncSession, initial_data: dict) -> list[User]: + return list((await User.fetch(session)).all()) - app.state.redis_pool = mocker.AsyncMock() - app.dependency_overrides[get_async_session] = get_session_override - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://localhost" - ) as client: - yield client - app.dependency_overrides.clear() + +@pytest.fixture +async def user(users: list[User]) -> User: + return [usr for usr in users if usr.email == "cmin764@gmail.com"][0] -async def _get_users(session: AsyncSession, email: str | None = None) -> list[User]: - filters = [] - if email: - filters.append(User.email == email) - result = await User.fetch(session, filters=filters) - users = [result.one()] if email else result.all() - return users +@pytest.fixture +async def secondary_user(users: list[User], user: User) -> User: + return [usr for usr in users if usr.email != user.email][0] @pytest.fixture -async def users(session: AsyncSession, initial_data: dict) -> list[User]: - return await _get_users(session) +async def _client_factory(_scoped_session: async_scoped_session, mocker): + async def get_session_override(): + async with _scoped_session() as session: + yield session + + async def _create_client(): + app.state.redis_pool = mocker.AsyncMock() + app.dependency_overrides[get_async_session] = get_session_override + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://localhost" + ) as client: + yield client + app.dependency_overrides.clear() + + return _create_client @pytest.fixture -async def user(users: list[User]) -> User: - return [usr for usr in users if usr.email == "cmin764@gmail.com"][0] +async def client(_client_factory): + async for client in _client_factory(): + yield client + + +@pytest.fixture +async def secondary_client(_client_factory): + async for client in _client_factory(): + yield client + +async def _get_auth_token(initial_data: dict, user: User, client: AsyncClient) -> str: + users_dump = initial_data["users"] + user_dump = [item for item in users_dump if item["email"] == user.email][0] -async def _get_auth_token(client: AsyncClient, *, email: str, password: str) -> str: # Authenticate and get the token. - form_data = {"username": email, "password": password} + form_data = {"username": user.email, "password": user_dump["password"]} response = await client.post("/v1/auth/access-token", data=form_data) assert response.status_code == 200 @@ -146,60 +175,29 @@ async def _get_auth_token(client: AsyncClient, *, email: str, password: str) -> @pytest.fixture -async def auth_tokens( - initial_data: dict, users: list[User], client: AsyncClient -) -> list[dict[str, str]]: - users_dump = initial_data["users"] - tokens = [] - for user in users: - user_dump = [item for item in users_dump if item["email"] == user.email][0] - token = await _get_auth_token( - client, email=user.email, password=user_dump["password"] - ) - tokens.append({"email": user.email, "token": token}) - return tokens - - -@pytest.fixture -async def auth_token(initial_data: dict, user: User, client: AsyncClient) -> str: - users_dump = initial_data["users"] - user_dump = [item for item in users_dump if item["email"] == user.email][0] - return await _get_auth_token( - client, email=user.email, password=user_dump["password"] - ) - - -@pytest.fixture -async def auth_client(client: AsyncClient, auth_token: str): - client.headers.update({"Authorization": f"Bearer {auth_token}"}) +async def auth_client(initial_data: dict, user: User, client: AsyncClient): + token = await _get_auth_token(initial_data, user, client) + client.headers.update({"Authorization": f"Bearer {token}"}) return client @pytest.fixture async def secondary_auth_client( - user: User, client: AsyncClient, auth_tokens: list[dict[str, str]] + initial_data: dict, secondary_user: User, secondary_client: AsyncClient ): - token = [tkn for tkn in auth_tokens if tkn["email"] != user.email][0] - client.headers.update({"Authorization": f"Bearer {token}"}) - return client + token = await _get_auth_token(initial_data, secondary_user, secondary_client) + secondary_client.headers.update({"Authorization": f"Bearer {token}"}) + return secondary_client -@pytest.fixture -async def cart_items( - session: AsyncSession, initial_data: dict, user: User -) -> list[CartItem]: +async def _create_cart_with_items(session: AsyncSession, user: User) -> list[CartItem]: cart = Cart(user_id=user.id) session.add(cart) await session.commit() await session.refresh(cart) items = [] - for ice_data in initial_data["icecream"]: - icecream = ( - await IceCream.fetch( - session, filters=[IceCream.flavor == ice_data["flavor"]] - ) - ).one() + for icecream in (await IceCream.fetch(session)).all(): cart_item = CartItem( cart_id=cart.id, icecream_id=icecream.id, @@ -210,13 +208,34 @@ async def cart_items( session.add_all(items) await session.commit() + items = list( + ( + await CartItem.fetch( + session, + filters=[CartItem.id == cart.id], + joinedloads=[CartItem.icecream], + ) + ).all() + ) return items +@pytest.fixture +async def cart_items(session: AsyncSession, user: User) -> list[CartItem]: + return await _create_cart_with_items(session, user) + + +@pytest.fixture +async def secondary_cart_items( + session: AsyncSession, secondary_user: User +) -> list[CartItem]: + return await _create_cart_with_items(session, secondary_user) + + @pytest.fixture async def order(session: AsyncSession, cart_items: list[CartItem], user: User) -> Order: cart_service = CartService(session) - cart = await cart_service.get_cart(user.id) + cart = await cart_service.ensure_cart(cast(int, user.id)) order_service = OrderService(session, stats_service=stats_service) order = await order_service.make_order_from_cart(cart) await session.commit() diff --git a/tests/test_payments.py b/tests/test_payments.py index 2931b51..1e2c8c3 100644 --- a/tests/test_payments.py +++ b/tests/test_payments.py @@ -1,15 +1,10 @@ +import asyncio from unittest.mock import call import pytest from deep_ice import app -from deep_ice.models import ( - Order, - OrderItem, - OrderStatus, - PaymentMethod, - PaymentStatus, -) +from deep_ice.models import Order, OrderItem, OrderStatus, PaymentMethod, PaymentStatus async def _check_order_creation(session, order_id, *, status, amount): @@ -30,13 +25,15 @@ async def _check_order_creation(session, order_id, *, status, amount): return order +def _get_icecream(initial_data, flavor): + for icecream in initial_data["icecream"]: + if icecream["flavor"] == flavor: + return icecream + + def _check_quantities(order, initial_data): - # For confirmed orders, ensure the stock was deducted correctly. - get_icecream = lambda flavor: [ - ice for ice in initial_data["icecream"] if ice["flavor"] == flavor - ][0] for item in order.items: - before = get_icecream(item.icecream.flavor)["stock"] + before = _get_icecream(initial_data, item.icecream.flavor)["stock"] after = item.icecream.stock assert after == before - item.quantity assert not item.icecream.blocked_quantity @@ -115,16 +112,31 @@ async def test_payment_redirect_insufficient_stock( assert response.status_code == 307 redirect_url = response.headers.get("Location") assert redirect_url.endswith("/v1/cart") - await session.refresh(first_item) assert first_item.quantity != initial_quantity assert first_item.quantity == max_quantity @pytest.mark.anyio async def test_concurrent_card_payments( - redis_client, session, auth_client, secondary_auth_client, cart_items + redis_client, + session, + auth_client, + cart_items, + secondary_auth_client, + secondary_cart_items, ): - # Two different customers triggering a card payment simultaneously. + for items in (cart_items, secondary_cart_items): + for item in items: + item.quantity = item.icecream.available_stock + session.add_all(items) + await session.commit() + + # Two different customers triggering a card payment each, simultaneously and + # without enough stock for both. args, kwargs = ["/v1/payments"], {"json": {"method": PaymentMethod.CARD.value}} - main_response = await auth_client.post(*args, **kwargs) - secondary_response = await secondary_auth_client.post(*args, **kwargs) + requests = [ + crt_client.post(*args, **kwargs) + for crt_client in (auth_client, secondary_auth_client) + ] + responses = await asyncio.gather(*requests) + print([resp.status_code for resp in responses])