Skip to content

Commit

Permalink
Ensure base for testing concurrent payments
Browse files Browse the repository at this point in the history
  • Loading branch information
cmin764 committed Dec 3, 2024
1 parent eecc28d commit 8c7091c
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 91 deletions.
2 changes: 1 addition & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


APP_PACKAGE = "deep_ice"
PACKAGES = f"{APP_PACKAGE} alembic"
PACKAGES = f"{APP_PACKAGE} alembic tests"


# Helper function to run commands with 'uv run' and provide CI-friendly logging.
Expand Down
165 changes: 92 additions & 73 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import asyncio
from typing import cast
from unittest.mock import AsyncMock

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import (
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
from sqlmodel import insert
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.pool import StaticPool
Expand All @@ -29,20 +35,28 @@ def redis_client(mocker):
)


@pytest.fixture(name="session")
async def session_fixture():
@pytest.fixture
async def _scoped_session():
async_engine = create_async_engine(
"sqlite+aiosqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

async_session = async_sessionmaker(
session_factory = async_sessionmaker(
bind=async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session() as session:
scoped_session = async_scoped_session(
session_factory, scopefunc=asyncio.current_task
)
yield scoped_session
await scoped_session.remove()


@pytest.fixture
async def session(_scoped_session: async_scoped_session):
async with _scoped_session() as session:
yield session


Expand Down Expand Up @@ -89,53 +103,68 @@ async def initial_data(session: AsyncSession) -> dict:
},
]

await session.exec(insert(IceCream).values(icecream_dump))
await session.exec(insert(IceCream).values(icecream_dump)) # type: ignore
clean_users_dump = [
{key: value for key, value in user_data.items() if key != "password"}
for user_data in users_dump
]
await session.exec(insert(User).values(clean_users_dump))
await session.exec(insert(User).values(clean_users_dump)) # type: ignore
await session.commit()

return {"icecream": icecream_dump, "users": users_dump}


@pytest.fixture(name="client")
async def client_fixture(session: AsyncSession, mocker):
async def get_session_override():
yield session
@pytest.fixture
async def users(session: AsyncSession, initial_data: dict) -> list[User]:
return list((await User.fetch(session)).all())

app.state.redis_pool = mocker.AsyncMock()
app.dependency_overrides[get_async_session] = get_session_override
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://localhost"
) as client:
yield client
app.dependency_overrides.clear()

@pytest.fixture
async def user(users: list[User]) -> User:
return [usr for usr in users if usr.email == "[email protected]"][0]


async def _get_users(session: AsyncSession, email: str | None = None) -> list[User]:
filters = []
if email:
filters.append(User.email == email)
result = await User.fetch(session, filters=filters)
users = [result.one()] if email else result.all()
return users
@pytest.fixture
async def secondary_user(users: list[User], user: User) -> User:
return [usr for usr in users if usr.email != user.email][0]


@pytest.fixture
async def users(session: AsyncSession, initial_data: dict) -> list[User]:
return await _get_users(session)
async def _client_factory(_scoped_session: async_scoped_session, mocker):
async def get_session_override():
async with _scoped_session() as session:
yield session

async def _create_client():
app.state.redis_pool = mocker.AsyncMock()
app.dependency_overrides[get_async_session] = get_session_override
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://localhost"
) as client:
yield client
app.dependency_overrides.clear()

return _create_client


@pytest.fixture
async def user(users: list[User]) -> User:
return [usr for usr in users if usr.email == "[email protected]"][0]
async def client(_client_factory):
async for client in _client_factory():
yield client


@pytest.fixture
async def secondary_client(_client_factory):
async for client in _client_factory():
yield client


async def _get_auth_token(initial_data: dict, user: User, client: AsyncClient) -> str:
users_dump = initial_data["users"]
user_dump = [item for item in users_dump if item["email"] == user.email][0]

async def _get_auth_token(client: AsyncClient, *, email: str, password: str) -> str:
# Authenticate and get the token.
form_data = {"username": email, "password": password}
form_data = {"username": user.email, "password": user_dump["password"]}
response = await client.post("/v1/auth/access-token", data=form_data)
assert response.status_code == 200

Expand All @@ -146,60 +175,29 @@ async def _get_auth_token(client: AsyncClient, *, email: str, password: str) ->


@pytest.fixture
async def auth_tokens(
initial_data: dict, users: list[User], client: AsyncClient
) -> list[dict[str, str]]:
users_dump = initial_data["users"]
tokens = []
for user in users:
user_dump = [item for item in users_dump if item["email"] == user.email][0]
token = await _get_auth_token(
client, email=user.email, password=user_dump["password"]
)
tokens.append({"email": user.email, "token": token})
return tokens


@pytest.fixture
async def auth_token(initial_data: dict, user: User, client: AsyncClient) -> str:
users_dump = initial_data["users"]
user_dump = [item for item in users_dump if item["email"] == user.email][0]
return await _get_auth_token(
client, email=user.email, password=user_dump["password"]
)


@pytest.fixture
async def auth_client(client: AsyncClient, auth_token: str):
client.headers.update({"Authorization": f"Bearer {auth_token}"})
async def auth_client(initial_data: dict, user: User, client: AsyncClient):
token = await _get_auth_token(initial_data, user, client)
client.headers.update({"Authorization": f"Bearer {token}"})
return client


@pytest.fixture
async def secondary_auth_client(
user: User, client: AsyncClient, auth_tokens: list[dict[str, str]]
initial_data: dict, secondary_user: User, secondary_client: AsyncClient
):
token = [tkn for tkn in auth_tokens if tkn["email"] != user.email][0]
client.headers.update({"Authorization": f"Bearer {token}"})
return client
token = await _get_auth_token(initial_data, secondary_user, secondary_client)
secondary_client.headers.update({"Authorization": f"Bearer {token}"})
return secondary_client


@pytest.fixture
async def cart_items(
session: AsyncSession, initial_data: dict, user: User
) -> list[CartItem]:
async def _create_cart_with_items(session: AsyncSession, user: User) -> list[CartItem]:
cart = Cart(user_id=user.id)
session.add(cart)
await session.commit()
await session.refresh(cart)

items = []
for ice_data in initial_data["icecream"]:
icecream = (
await IceCream.fetch(
session, filters=[IceCream.flavor == ice_data["flavor"]]
)
).one()
for icecream in (await IceCream.fetch(session)).all():
cart_item = CartItem(
cart_id=cart.id,
icecream_id=icecream.id,
Expand All @@ -210,13 +208,34 @@ async def cart_items(
session.add_all(items)
await session.commit()

items = list(
(
await CartItem.fetch(
session,
filters=[CartItem.id == cart.id],
joinedloads=[CartItem.icecream],
)
).all()
)
return items


@pytest.fixture
async def cart_items(session: AsyncSession, user: User) -> list[CartItem]:
return await _create_cart_with_items(session, user)


@pytest.fixture
async def secondary_cart_items(
session: AsyncSession, secondary_user: User
) -> list[CartItem]:
return await _create_cart_with_items(session, secondary_user)


@pytest.fixture
async def order(session: AsyncSession, cart_items: list[CartItem], user: User) -> Order:
cart_service = CartService(session)
cart = await cart_service.get_cart(user.id)
cart = await cart_service.ensure_cart(cast(int, user.id))
order_service = OrderService(session, stats_service=stats_service)
order = await order_service.make_order_from_cart(cart)
await session.commit()
Expand Down
46 changes: 29 additions & 17 deletions tests/test_payments.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import asyncio
from unittest.mock import call

import pytest

from deep_ice import app
from deep_ice.models import (
Order,
OrderItem,
OrderStatus,
PaymentMethod,
PaymentStatus,
)
from deep_ice.models import Order, OrderItem, OrderStatus, PaymentMethod, PaymentStatus


async def _check_order_creation(session, order_id, *, status, amount):
Expand All @@ -30,13 +25,15 @@ async def _check_order_creation(session, order_id, *, status, amount):
return order


def _get_icecream(initial_data, flavor):
for icecream in initial_data["icecream"]:
if icecream["flavor"] == flavor:
return icecream


def _check_quantities(order, initial_data):
# For confirmed orders, ensure the stock was deducted correctly.
get_icecream = lambda flavor: [
ice for ice in initial_data["icecream"] if ice["flavor"] == flavor
][0]
for item in order.items:
before = get_icecream(item.icecream.flavor)["stock"]
before = _get_icecream(initial_data, item.icecream.flavor)["stock"]
after = item.icecream.stock
assert after == before - item.quantity
assert not item.icecream.blocked_quantity
Expand Down Expand Up @@ -115,16 +112,31 @@ async def test_payment_redirect_insufficient_stock(
assert response.status_code == 307
redirect_url = response.headers.get("Location")
assert redirect_url.endswith("/v1/cart")
await session.refresh(first_item)
assert first_item.quantity != initial_quantity
assert first_item.quantity == max_quantity


@pytest.mark.anyio
async def test_concurrent_card_payments(
redis_client, session, auth_client, secondary_auth_client, cart_items
redis_client,
session,
auth_client,
cart_items,
secondary_auth_client,
secondary_cart_items,
):
# Two different customers triggering a card payment simultaneously.
for items in (cart_items, secondary_cart_items):
for item in items:
item.quantity = item.icecream.available_stock
session.add_all(items)
await session.commit()

# Two different customers triggering a card payment each, simultaneously and
# without enough stock for both.
args, kwargs = ["/v1/payments"], {"json": {"method": PaymentMethod.CARD.value}}
main_response = await auth_client.post(*args, **kwargs)
secondary_response = await secondary_auth_client.post(*args, **kwargs)
requests = [
crt_client.post(*args, **kwargs)
for crt_client in (auth_client, secondary_auth_client)
]
responses = await asyncio.gather(*requests)
print([resp.status_code for resp in responses])

0 comments on commit 8c7091c

Please sign in to comment.