diff --git a/.env.template b/.env.template index 02989aa..ca866e7 100644 --- a/.env.template +++ b/.env.template @@ -1,8 +1,8 @@ LOG_LEVEL=INFO +DEBUG=false # Postgres POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 POSTGRES_USER=deep POSTGRES_PASSWORD=icecream POSTGRES_DB=deep_ice diff --git a/Dockerfile b/Dockerfile index c274c58..9e3984a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,12 +4,12 @@ WORKDIR /app # Install uv deps with pip. RUN pip install uv COPY pyproject.toml . -RUN uv export --no-dev >requirements.txt && pip install -Ur requirements.txt +RUN uv pip install --system -Ur pyproject.toml -# Copy the rest of the application code. +# Copy the rest of the application code and install the project too. COPY . . +RUN uv pip install --system -e . +# Run the FastAPI app using uvicorn on default port. EXPOSE 80 - -# Command to run the FastAPI app using uvicorn. CMD ["fastapi", "run", "deep_ice", "--port", "80"] diff --git a/deep_ice/__init__.py b/deep_ice/__init__.py index db57a49..973501c 100644 --- a/deep_ice/__init__.py +++ b/deep_ice/__init__.py @@ -26,7 +26,7 @@ async def lifespan(fast_app: FastAPI): class TaskQueue: functions = [payment_service.make_payment_task] redis_settings = redis_settings - max_tries = settings.TASK_MAX_RETRIES + max_tries = settings.TASK_MAX_TRIES retry_delay = settings.TASK_RETRY_DELAY diff --git a/deep_ice/api/routes/cart.py b/deep_ice/api/routes/cart.py index 9351b93..c88939c 100644 --- a/deep_ice/api/routes/cart.py +++ b/deep_ice/api/routes/cart.py @@ -43,13 +43,14 @@ async def get_cart_items(current_user: CurrentUserDep, cart_service: CartService return cart -@router.post("/items", response_model=RetrieveCartItem) +@router.post( + "/items", response_model=RetrieveCartItem, status_code=status.HTTP_201_CREATED +) async def add_item_to_cart( session: SessionDep, current_user: CurrentUserDep, cart_service: CartServiceDep, item: Annotated[CreateCartItem, Body()], - response: Response, ): cart = await cart_service.ensure_cart(cast(int, current_user.id)) cart_item = CartItem(cart_id=cart.id, **item.model_dump()) @@ -65,7 +66,6 @@ async def add_item_to_cart( else: cart_item.icecream = icecream - response.status_code = status.HTTP_201_CREATED return cart_item diff --git a/deep_ice/api/routes/payments.py b/deep_ice/api/routes/payments.py index 45f9a9a..3a10625 100644 --- a/deep_ice/api/routes/payments.py +++ b/deep_ice/api/routes/payments.py @@ -1,14 +1,20 @@ from typing import Annotated, cast import sentry_sdk +from aioredlock import LockError from fastapi import APIRouter, Body, HTTPException, Request, Response, status from fastapi.responses import RedirectResponse from sqlalchemy.exc import SQLAlchemyError +from sqlmodel.ext.asyncio.session import AsyncSession from deep_ice.core import logger -from deep_ice.core.dependencies import CurrentUserDep, SessionDep -from deep_ice.models import PaymentMethod, PaymentStatus, RetrievePayment -from deep_ice.services.cart import CartService +from deep_ice.core.dependencies import ( + CartServiceDep, + CurrentUserDep, + RedlockDep, + SessionDep, +) +from deep_ice.models import Cart, Payment, PaymentMethod, PaymentStatus, RetrievePayment from deep_ice.services.order import OrderService from deep_ice.services.payment import PaymentError, PaymentService, payment_stub from deep_ice.services.stats import stats_service @@ -16,30 +22,9 @@ router = APIRouter() -@router.post("", response_model=RetrievePayment) -async def make_payment( - session: SessionDep, - current_user: CurrentUserDep, - method: Annotated[PaymentMethod, Body(embed=True)], - request: Request, - response: Response, -): - # FIXME(cmin764): Check if we need an async Lock primitive here in order to allow - # only one user to submit an order at a time. (based on available stock check) - cart_service = CartService(session) - cart = await cart_service.get_cart(cast(int, current_user.id)) - if not cart or not cart.items: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="There are no items in the cart", - ) - - cart_ok = await cart_service.check_items_against_stock(cart) - if not cart_ok: - # Redirect back to the cart so we get aware of the new state based on the - # available stock. And let the user decide if it continues with a payment. - return RedirectResponse(url=request.url_for("get_cart_items")) - +async def _make_payment( + session: AsyncSession, *, cart: Cart, method: PaymentMethod, response: Response +) -> Payment: # Items are available and ready to be sold, make the order and pay for it. order_service = OrderService(session, stats_service=stats_service) payment_service = PaymentService( @@ -62,14 +47,56 @@ async def make_payment( raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Payment failed" ) - else: - await session.commit() - response.status_code = ( - status.HTTP_202_ACCEPTED - if payment.status == PaymentStatus.PENDING - else status.HTTP_201_CREATED + + await session.commit() + response.status_code = ( + status.HTTP_202_ACCEPTED + if payment.status == PaymentStatus.PENDING + else status.HTTP_201_CREATED + ) + return payment + + +@router.post("", response_model=RetrievePayment) +async def make_payment( + session: SessionDep, + current_user: CurrentUserDep, + cart_service: CartServiceDep, + redlock: RedlockDep, + method: Annotated[PaymentMethod, Body(embed=True)], + request: Request, + response: Response, +): + cart = await cart_service.get_cart(cast(int, current_user.id)) + if not cart or not cart.items: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="There are no items in the cart", ) - return payment + + lock_keys = [f"ice-lock:{item.icecream_id}" for item in cart.items] + locks = [] + try: + for lock_key in lock_keys: + lock = await redlock.lock(lock_key) + locks.append(lock) + + cart_ok = await cart_service.check_items_against_stock(cart) + if not cart_ok: + # Redirect back to the cart so we get aware of the new state based on + # the available stock. And let the user decide if it continues with a + # payment. + return RedirectResponse(url=request.url_for("get_cart_items")) + + return await _make_payment( + session, cart=cart, method=method, response=response + ) + except LockError as exc: + logger.exception("Payment lock error with key %r: %s", lock_key, exc) + sentry_sdk.capture_exception(exc) + finally: + for lock in locks: + await redlock.unlock(lock) @router.get("", response_model=list[RetrievePayment]) diff --git a/deep_ice/core/config.py b/deep_ice/core/config.py index fd72c26..857480e 100644 --- a/deep_ice/core/config.py +++ b/deep_ice/core/config.py @@ -10,11 +10,12 @@ class Settings(BaseSettings): model_config = SettingsConfigDict( # Use the top level .env file (one level above ./deep_ice/). env_file=".env", - env_ignore_empty=True, + env_ignore_empty=False, extra="ignore", ) LOG_LEVEL: str = "INFO" + DEBUG: bool = False PROJECT_NAME: str = "Deep Ice" API_V1_STR: str = "/v1" @@ -30,8 +31,10 @@ class Settings(BaseSettings): POSTGRES_DB: str REDIS_HOST: str = "localhost" + REDIS_PORT: int = 6379 + REDLOCK_TTL: int = 30 # seconds for the lock to persists in Redis - TASK_MAX_RETRIES: int = 3 + TASK_MAX_TRIES: int = 3 TASK_RETRY_DELAY: int = 1 # seconds between retries TASK_BACKOFF_FACTOR: int = 5 # seconds to wait based on the job try counter @@ -52,4 +55,4 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: settings = Settings() # type: ignore -redis_settings = RedisSettings(host=settings.REDIS_HOST) +redis_settings = RedisSettings(host=settings.REDIS_HOST, port=settings.REDIS_PORT) diff --git a/deep_ice/core/database.py b/deep_ice/core/database.py index 8cf9b94..ab386ad 100644 --- a/deep_ice/core/database.py +++ b/deep_ice/core/database.py @@ -6,7 +6,7 @@ from deep_ice.core.config import settings async_engine = create_async_engine( - str(settings.SQLALCHEMY_DATABASE_URI), echo=True, future=True + str(settings.SQLALCHEMY_DATABASE_URI), echo=settings.DEBUG, future=True ) diff --git a/deep_ice/core/dependencies.py b/deep_ice/core/dependencies.py index 051e1de..2c7bdc7 100644 --- a/deep_ice/core/dependencies.py +++ b/deep_ice/core/dependencies.py @@ -1,7 +1,8 @@ -from typing import Annotated +from typing import Annotated, AsyncGenerator import jwt import sentry_sdk +from aioredlock import Aioredlock from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError @@ -9,7 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from deep_ice.core import logger, security -from deep_ice.core.config import settings +from deep_ice.core.config import redis_settings, settings from deep_ice.core.database import get_async_session from deep_ice.models import TokenPayload, User from deep_ice.services.cart import CartService @@ -51,5 +52,15 @@ async def get_cart_service(session: SessionDep) -> CartService: return CartService(session) +async def get_lock_manager() -> AsyncGenerator[Aioredlock, None]: + lock_manager = Aioredlock( + [{"host": redis_settings.host, "port": redis_settings.port}], + internal_lock_timeout=settings.REDLOCK_TTL, + ) + yield lock_manager + await lock_manager.destroy() + + CurrentUserDep = Annotated[User, Depends(get_current_user)] CartServiceDep = Annotated[CartService, Depends(get_cart_service)] +RedlockDep = Annotated[Aioredlock, Depends(get_lock_manager)] diff --git a/deep_ice/models.py b/deep_ice/models.py index 2a2f01a..97e51ee 100644 --- a/deep_ice/models.py +++ b/deep_ice/models.py @@ -68,7 +68,7 @@ class BaseIceCream(SQLModel): class IceCream(BaseIceCream, FetchMixin, table=True): id: Annotated[int | None, Field(primary_key=True)] = None stock: int - blocked_quantity: int = 0 # reserved for payments only + blocked_quantity: int = 0 # reserved during payments is_active: bool = True cart_items: list["CartItem"] = Relationship( diff --git a/deep_ice/services/cart.py b/deep_ice/services/cart.py index 548101a..1d0dcdf 100644 --- a/deep_ice/services/cart.py +++ b/deep_ice/services/cart.py @@ -34,8 +34,16 @@ async def ensure_cart(self, user_id: int) -> Cart: return cart + async def _refresh_icecream_stock(self, cart: Cart) -> None: + await self._session.refresh(cart) + for cart_item in cart.items: + await self._session.refresh(cart_item) + cart_item.icecream = await cart_item.awaitable_attrs.icecream + await self._session.refresh(cart_item.icecream) + async def check_items_against_stock(self, cart: Cart) -> bool: # Ensure once again that we still have on stock the items we intend to buy. + await self._refresh_icecream_stock(cart) cart_ok = True for item in cart.items: if item.quantity > item.icecream.available_stock: diff --git a/deep_ice/services/order.py b/deep_ice/services/order.py index 2d962db..e7b6f80 100644 --- a/deep_ice/services/order.py +++ b/deep_ice/services/order.py @@ -43,12 +43,11 @@ async def confirm_order(self, order_id: int): icecream.stock -= item.quantity icecream.blocked_quantity -= item.quantity + self._session.add(icecream) await self._stats_service.acknowledge_icecream_demand( cast(int, icecream.id), name=icecream.name, quantity=item.quantity ) - self._session.add_all(order.items) - async def cancel_order(self, order_id: int): order = await self._get_order(order_id) order.status = OrderStatus.CANCELLED @@ -62,7 +61,7 @@ async def cancel_order(self, order_id: int): continue icecream.blocked_quantity -= item.quantity - self._session.add_all(order.items) + self._session.add(icecream) async def make_order_from_cart(self, cart: Cart) -> Order: # Creates and saves an order out of the current cart and returns it for later diff --git a/deep_ice/services/payment.py b/deep_ice/services/payment.py index 24f7b34..9211112 100644 --- a/deep_ice/services/payment.py +++ b/deep_ice/services/payment.py @@ -33,7 +33,7 @@ async def make_payment_task( msg = f"{method.value} payment for order #{order_id} failed, retrying..." logger.warning(msg) sentry_sdk.capture_message(msg, level="warning") - raise Retry(defer=attempts * settings.TASK_BACKOFF_FACTOR) + raise Retry(defer=attempts * settings.TASK_BACKOFF_FACTOR) async for session in get_async_session(): order_service = OrderService(session, stats_service=stats_service) @@ -82,7 +82,9 @@ class PaymentStub(PaymentInterface): min_delay: int max_delay: int - allow_failures: bool = False # enable failures or not + # Enable failures (or not) and at what rate. + allow_failures: bool = False + failure_rate: float = 0.2 async def make_payment( self, @@ -124,7 +126,9 @@ async def make_payment( if self.allow_failures: # Simulate payment result: 80% chance of success, 20% chance of failure. payment_result = random.choices( - [PaymentStatus.SUCCESS, PaymentStatus.FAILED], weights=[80, 20], k=1 + [PaymentStatus.SUCCESS, PaymentStatus.FAILED], + weights=[1 - self.failure_rate, self.failure_rate], + k=1, )[0] else: payment_result = PaymentStatus.SUCCESS @@ -196,4 +200,4 @@ async def set_order_payment_status(self, order_id: int, status: PaymentStatus): self._session.add(payment) -payment_stub = PaymentStub(1, 3, allow_failures=True) +payment_stub = PaymentStub(1, 3, allow_failures=True, failure_rate=0.2) diff --git a/deep_ice/services/stats.py b/deep_ice/services/stats.py index 7c6292e..eda2b2b 100644 --- a/deep_ice/services/stats.py +++ b/deep_ice/services/stats.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict -import redis.asyncio as redis +import redis.asyncio as aioredis from deep_ice.core.config import redis_settings @@ -22,7 +22,9 @@ class StatsService(StatsInterface): POPULARITY_KEY = "POPULAR_ICECREAM" def __init__(self): - self._client = redis.Redis(host=redis_settings.host) + self._client = aioredis.Redis( + host=redis_settings.host, port=redis_settings.port + ) @staticmethod def _get_product_key(*args: int | str) -> str: diff --git a/pyproject.toml b/pyproject.toml index ccc34da..71f172a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ dependencies = [ "arq>=0.26.1", "redis>=5.1.1", "sentry-sdk[arq,fastapi]>=2.18.0", + "aioredlock>=0.7.3", + "setuptools>=75.6.0", ] [tool.uv] @@ -32,8 +34,12 @@ dev-dependencies = [ "pytest>=8.3.3", "ruff>=0.6.9", "types-passlib>=1.7.7.20240819", + "pytest-env>=1.1.5", ] +[tool.setuptools] +packages = ["deep_ice"] + [tool.pytest.ini_options] # Explicitly set the loop scope for asyncio fixtures to avoid the deprecation warning asyncio_default_fixture_loop_scope = "function" @@ -43,6 +49,11 @@ filterwarnings = [ "ignore::UserWarning", ] addopts = "--disable-warnings" +env = [ + "SENTRY_DSN = ", # disable Sentry reporting during testing + "LOG_LEVEL = INFO", + "DEBUG = true" +] [tool.flake8] # Check that this is aligned with your other tools like Black @@ -68,3 +79,6 @@ line-length = 88 [tool.ruff] line-length = 88 + +[tool.mypy] +ignore_missing_imports = true diff --git a/tasks.py b/tasks.py index 264df89..2fc8f22 100644 --- a/tasks.py +++ b/tasks.py @@ -2,7 +2,7 @@ APP_PACKAGE = "deep_ice" -PACKAGES = f"{APP_PACKAGE} alembic" +PACKAGES = f"{APP_PACKAGE} alembic tests" # Helper function to run commands with 'uv run' and provide CI-friendly logging. @@ -114,5 +114,5 @@ def run_worker(ctx, develop: bool = False): """Run a worker for processing the task queue in production or development mode.""" params = "" if develop: - params = "--watch deep_ice" + params = f"--watch {APP_PACKAGE}" uv_run(ctx, f"arq {APP_PACKAGE}.TaskQueue {params}", "Task queue worker") diff --git a/tests/conftest.py b/tests/conftest.py index d231710..858599e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,28 @@ +import asyncio +import functools +from typing import cast from unittest.mock import AsyncMock import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import ( + async_scoped_session, + async_sessionmaker, + close_all_sessions, + create_async_engine, +) +from sqlmodel import insert +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlmodel.pool import StaticPool + from deep_ice import app from deep_ice.core.database import get_async_session +from deep_ice.core.dependencies import get_lock_manager from deep_ice.core.security import get_password_hash from deep_ice.models import Cart, CartItem, IceCream, Order, SQLModel, User from deep_ice.services.cart import CartService from deep_ice.services.order import OrderService from deep_ice.services.stats import stats_service -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlmodel import insert -from sqlmodel.ext.asyncio.session import AsyncSession -from sqlmodel.pool import StaticPool # Run tests with `asyncio` only. @@ -28,8 +38,8 @@ def redis_client(mocker): ) -@pytest.fixture(name="session") -async def session_fixture(): +@pytest.fixture +async def _scoped_session_factory(): async_engine = create_async_engine( "sqlite+aiosqlite://", connect_args={"check_same_thread": False}, @@ -37,11 +47,21 @@ async def session_fixture(): ) async with async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - - async_session = async_sessionmaker( + session_factory = async_sessionmaker( bind=async_engine, class_=AsyncSession, expire_on_commit=False ) - async with async_session() as session: + scoped_session_factory = async_scoped_session( + session_factory, scopefunc=asyncio.current_task + ) + yield scoped_session_factory + await close_all_sessions() + await scoped_session_factory.remove() + await async_engine.dispose() + + +@pytest.fixture +async def session(_scoped_session_factory: async_scoped_session): + async with _scoped_session_factory() as session: yield session @@ -71,45 +91,107 @@ async def initial_data(session: AsyncSession) -> dict: { "name": "Cosmin Poieana", "email": "cmin764@gmail.com", + "password": "cosmin-password", "hashed_password": get_password_hash("cosmin-password"), }, { "name": "John Doe", "email": "john.doe@deepicecream.ai", + "password": "john-password", "hashed_password": get_password_hash("john-password"), }, { "name": "Sam Smith", "email": "sam.smith@deepicecream.ai", + "password": "sam-password", "hashed_password": get_password_hash("sam-password"), }, ] - await session.exec(insert(IceCream).values(icecream_dump)) - await session.exec(insert(User).values(users_dump)) + await session.exec(insert(IceCream).values(icecream_dump)) # type: ignore + clean_users_dump = [ + {key: value for key, value in user_data.items() if key != "password"} + for user_data in users_dump + ] + await session.exec(insert(User).values(clean_users_dump)) # type: ignore await session.commit() return {"icecream": icecream_dump, "users": users_dump} -@pytest.fixture(name="client") -async def client_fixture(session: AsyncSession, mocker): - async def get_session_override(): - yield session +@pytest.fixture +async def users(session: AsyncSession, initial_data: dict) -> list[User]: + return list((await User.fetch(session)).all()) + + +@pytest.fixture +async def user(users: list[User]) -> User: + return [usr for usr in users if usr.email == "cmin764@gmail.com"][0] + + +@pytest.fixture +async def secondary_user(users: list[User], user: User) -> User: + return [usr for usr in users if usr.email != user.email][0] + + +class AsyncLockManager: + @staticmethod + @functools.lru_cache(maxsize=3) # number of ice cream flavors + def _get_lock(key: str, loop): # cache by key and current event loop + return asyncio.Lock() - app.state.redis_pool = mocker.AsyncMock() - app.dependency_overrides[get_async_session] = get_session_override - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://localhost" - ) as client: + @classmethod + async def lock(cls, key: str): + current_loop = asyncio.get_running_loop() + lock = cls._get_lock(key, loop=current_loop) + await lock.acquire() + return lock + + @staticmethod + async def unlock(lock: asyncio.Lock): + lock.release() + + +@pytest.fixture +async def _client_factory(_scoped_session_factory: async_scoped_session, mocker): + async def _get_async_session_override(): + async with _scoped_session_factory() as session: + yield session + + async def _get_lock_manager_override(): + return AsyncLockManager() + + async def _create_client(): + app.state.redis_pool = mocker.AsyncMock() + app.dependency_overrides[get_async_session] = _get_async_session_override + app.dependency_overrides[get_lock_manager] = _get_lock_manager_override + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://localhost" + ) as client: + yield client + app.dependency_overrides.clear() + + return _create_client + + +@pytest.fixture +async def client(_client_factory): + async for client in _client_factory(): yield client - app.dependency_overrides.clear() @pytest.fixture -async def auth_token(initial_data: dict, client: AsyncClient) -> str: +async def secondary_client(_client_factory): + async for client in _client_factory(): + yield client + + +async def _get_auth_token(initial_data: dict, user: User, client: AsyncClient) -> str: + users_dump = initial_data["users"] + user_dump = [item for item in users_dump if item["email"] == user.email][0] + # Authenticate and get the token. - form_data = {"username": "cmin764@gmail.com", "password": "cosmin-password"} + form_data = {"username": user.email, "password": user_dump["password"]} response = await client.post("/v1/auth/access-token", data=form_data) assert response.status_code == 200 @@ -120,52 +202,58 @@ async def auth_token(initial_data: dict, client: AsyncClient) -> str: @pytest.fixture -async def auth_client(client: AsyncClient, auth_token: str): - client.headers.update({"Authorization": f"Bearer {auth_token}"}) +async def auth_client(initial_data: dict, user: User, client: AsyncClient): + token = await _get_auth_token(initial_data, user, client) + client.headers.update({"Authorization": f"Bearer {token}"}) return client @pytest.fixture -async def user(session: AsyncSession) -> User: - user = ( - await User.fetch(session, filters=[User.email == "cmin764@gmail.com"]) - ).one() - return user +async def secondary_auth_client( + initial_data: dict, secondary_user: User, secondary_client: AsyncClient +): + token = await _get_auth_token(initial_data, secondary_user, secondary_client) + secondary_client.headers.update({"Authorization": f"Bearer {token}"}) + return secondary_client -@pytest.fixture -async def cart_items( - session: AsyncSession, initial_data: dict, user: User -) -> list[CartItem]: +async def _create_cart_with_items(session: AsyncSession, user: User) -> list[CartItem]: cart = Cart(user_id=user.id) session.add(cart) await session.commit() await session.refresh(cart) items = [] - for ice_data in initial_data["icecream"]: - icecream = ( - await IceCream.fetch( - session, filters=[IceCream.flavor == ice_data["flavor"]] - ) - ).one() + for icecream in (await IceCream.fetch(session)).all(): cart_item = CartItem( cart_id=cart.id, - icecream_id=icecream.id, + icecream=icecream, quantity=icecream.available_stock // 10, ) items.append(cart_item) - session.add_all(items) await session.commit() - return items + await session.refresh(cart) + return await cart.awaitable_attrs.items + + +@pytest.fixture +async def cart_items(session: AsyncSession, user: User) -> list[CartItem]: + return await _create_cart_with_items(session, user) + + +@pytest.fixture +async def secondary_cart_items( + session: AsyncSession, secondary_user: User +) -> list[CartItem]: + return await _create_cart_with_items(session, secondary_user) @pytest.fixture async def order(session: AsyncSession, cart_items: list[CartItem], user: User) -> Order: cart_service = CartService(session) - cart = await cart_service.get_cart(user.id) + cart = await cart_service.ensure_cart(cast(int, user.id)) order_service = OrderService(session, stats_service=stats_service) order = await order_service.make_order_from_cart(cart) await session.commit() diff --git a/tests/test_payments.py b/tests/test_payments.py index 4ad8381..ad9298a 100644 --- a/tests/test_payments.py +++ b/tests/test_payments.py @@ -1,3 +1,5 @@ +import asyncio +import itertools from unittest.mock import call import pytest @@ -24,13 +26,15 @@ async def _check_order_creation(session, order_id, *, status, amount): return order +def _get_icecream(initial_data, flavor): + for icecream in initial_data["icecream"]: + if icecream["flavor"] == flavor: + return icecream + + def _check_quantities(order, initial_data): - # For confirmed orders, ensure the stock was deducted correctly. - get_icecream = lambda flavor: [ - ice for ice in initial_data["icecream"] if ice["flavor"] == flavor - ][0] for item in order.items: - before = get_icecream(item.icecream.flavor)["stock"] + before = _get_icecream(initial_data, item.icecream.flavor)["stock"] after = item.icecream.stock assert after == before - item.quantity assert not item.icecream.blocked_quantity @@ -85,3 +89,113 @@ async def test_make_successful_payment( data = response.json() assert len(data) == 1 assert data[0]["method"] == method.value + + +@pytest.mark.anyio +async def test_payment_empty_cart(redis_client, session, auth_client): + response = await auth_client.post( + "/v1/payments", json={"method": PaymentMethod.CASH.value} + ) + assert response.status_code == 404 + + +@pytest.mark.anyio +async def test_payment_redirect_insufficient_stock( + redis_client, session, auth_client, cart_items +): + # Simulate purchase of some icecream which became unavailable in the meantime. + first_item = cart_items[0] + icecream = await first_item.awaitable_attrs.icecream + initial_quantity = first_item.quantity + max_quantity = initial_quantity - 1 + icecream.stock = max_quantity + session.add(icecream) + await session.commit() + + # Simulate payment and check if the redirect happened including the cart item + # quantity update to the new maximum available quantity for that ice cream flavor. + response = await auth_client.post( + "/v1/payments", json={"method": PaymentMethod.CASH.value} + ) + assert response.status_code == 307 + redirect_url = response.headers.get("Location") + assert redirect_url.endswith("/v1/cart") + assert first_item.quantity != initial_quantity + assert first_item.quantity == max_quantity + + +async def _clients_requests(path, *, _clients, _method, _payloads=None, **payload): + paths = path if isinstance(path, list | tuple) else [path] + payloads = (_payloads or []) + [payload] + requests = [ + getattr(client, _method)(path, json=payload) + for client in _clients + for path, payload in zip(paths, payloads) + ] + responses = await asyncio.gather(*requests) + return responses + + +@pytest.mark.parametrize("quantity_factor", [1, 0.5]) +@pytest.mark.parametrize("method", list(PaymentMethod)) +@pytest.mark.anyio +async def test_concurrent_payments( + redis_client, + session, + auth_client, + cart_items, + secondary_auth_client, + secondary_cart_items, + method, + quantity_factor, +): + # Two greedy customers add the whole stock (or a part of it) at the same time to + # each of their carts. + requests_tasks = [] + for selected_client, items in ( + (auth_client, cart_items), + (secondary_auth_client, secondary_cart_items), + ): + paths, payloads = [], [] + for item in items: + paths.append(f"/v1/cart/items/{item.id}") + payloads.append( + {"quantity": int(item.icecream.available_stock * quantity_factor)} + ) + requests_tasks.append( + _clients_requests( + paths, + _clients=[selected_client], + _method="put", + _payloads=payloads, + ) + ) + responses = itertools.chain(*(await asyncio.gather(*requests_tasks))) + for response in responses: + assert response.status_code == 200 + + # Two different customers triggering a payment each, simultaneously and potentially + # without enough stock for both. + responses = await _clients_requests( + "/v1/payments", + _clients=[auth_client, secondary_auth_client], + _method="post", + method=method.value, + ) + one_code = 201 if method is PaymentMethod.CASH else 202 + other_code = 307 if quantity_factor > 0.5 else one_code + assert sorted(response.status_code for response in responses) == sorted( + [one_code, other_code] + ) + + # Only one of the two can successfully order when trying to overbuy. + orders = (await Order.fetch(session, joinedloads=[Order.items])).unique().all() + expected_orders_count = 1 if quantity_factor > 0.5 else 2 + assert len(orders) == expected_orders_count + expected_status = ( + OrderStatus.CONFIRMED if method is PaymentMethod.CASH else OrderStatus.PENDING + ) + expected_amount = 1110.0 * quantity_factor + for order in orders: + assert order.status is expected_status + assert order.amount == expected_amount diff --git a/uv.lock b/uv.lock index 7e119d8..c33e5dd 100644 --- a/uv.lock +++ b/uv.lock @@ -5,6 +5,32 @@ resolution-markers = [ "python_full_version >= '3.13'", ] +[[package]] +name = "aioredis" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout" }, + { name = "hiredis" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2c/2a/662e5e79dde5d00964b995d50e38ecdefeeeb09b37edafff193c7e850f46/aioredis-1.3.1.tar.gz", hash = "sha256:15f8af30b044c771aee6787e5ec24694c048184c7b9e54c3b60c750a4b93273a", size = 155577 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/64/1b1612d0a104f21f80eb4c6e1b6075f2e6aba8e228f46f229cfd3fdac859/aioredis-1.3.1-py3-none-any.whl", hash = "sha256:b61808d7e97b7cd5a92ed574937a079c9387fdadd22bfbfa7ad2fd319ecc26e3", size = 65259 }, +] + +[[package]] +name = "aioredlock" +version = "0.7.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aioredis" }, + { name = "attrs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/6b/1e8ab48dbcfe802d1d07dece32bb5eea02bd494e0e3e5d8e8629c136d9ca/aioredlock-0.7.3.tar.gz", hash = "sha256:903727b26eb571c926018a8ae2b754c6c11861996410e3c4f1309872d2545440", size = 11923 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/67/ac617ecad2cbf12e639e9ed50f2ef956e415a8c0b47e6442017ce25c44b4/aioredlock-0.7.3-py3-none-any.whl", hash = "sha256:7432fe17cf2ce55292409f4e80d26af5ccbf1a09aa4566e30bcfc5dabd4b3e1f", size = 12903 }, +] + [[package]] name = "aiosqlite" version = "0.20.0" @@ -66,6 +92,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/9d/0fc5a5a08453b0c05c118ced6c62720063a1bde60d10ff579611c29b25cb/arq-0.26.1-py3-none-any.whl", hash = "sha256:789d12ca7d69919bd2e641e44f3f14a38bd854daa7ded22cdd725796e8c65352", size = 25891 }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, +] + [[package]] name = "asyncpg" version = "0.29.0" @@ -82,6 +117,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/71/86/7a18e1a457afb73991e5e5586e2341af09a31c91d8f65cc003f0b4553252/asyncpg-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe", size = 530253 }, ] +[[package]] +name = "attrs" +version = "24.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/48/c8/6260f8ccc11f0917360fc0da435c5c9c7504e3db174d5a12a1494887b045/attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff", size = 805984 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/aa/ab0f7891a01eeb2d2e338ae8fecbe57fcebea1a24dbb64d45801bfab481d/attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308", size = 63397 }, +] + [[package]] name = "bcrypt" version = "4.2.0" @@ -171,6 +215,7 @@ name = "deep-ice" version = "1.3.0" source = { virtual = "." } dependencies = [ + { name = "aioredlock" }, { name = "alembic" }, { name = "arq" }, { name = "asyncpg" }, @@ -182,6 +227,7 @@ dependencies = [ { name = "pyjwt" }, { name = "redis" }, { name = "sentry-sdk", extra = ["arq", "fastapi"] }, + { name = "setuptools" }, { name = "sqlmodel" }, ] @@ -195,6 +241,7 @@ dev = [ { name = "mypy" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-env" }, { name = "pytest-mock" }, { name = "ruff" }, { name = "types-passlib" }, @@ -202,6 +249,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aioredlock", specifier = ">=0.7.3" }, { name = "alembic", specifier = ">=1.13.3" }, { name = "arq", specifier = ">=0.26.1" }, { name = "asyncpg", specifier = ">=0.29.0" }, @@ -213,6 +261,7 @@ requires-dist = [ { name = "pyjwt", specifier = ">=2.9.0" }, { name = "redis", specifier = ">=5.1.1" }, { name = "sentry-sdk", extras = ["arq", "fastapi"], specifier = ">=2.18.0" }, + { name = "setuptools", specifier = ">=75.6.0" }, { name = "sqlmodel", specifier = ">=0.0.22" }, ] @@ -226,6 +275,7 @@ dev = [ { name = "mypy", specifier = ">=1.12.0" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, + { name = "pytest-env", specifier = ">=1.1.5" }, { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "ruff", specifier = ">=0.6.9" }, { name = "types-passlib", specifier = ">=1.7.7.20240819" }, @@ -762,6 +812,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/96/31/6607dab48616902f76885dfcf62c08d929796fc3b2d2318faf9fd54dbed9/pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b", size = 18024 }, ] +[[package]] +name = "pytest-env" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1f/31/27f28431a16b83cab7a636dce59cf397517807d247caa38ee67d65e71ef8/pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf", size = 8911 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/b8/87cfb16045c9d4092cfcf526135d73b88101aac83bc1adcf82dfb5fd3833/pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30", size = 6141 }, +] + [[package]] name = "pytest-mock" version = "3.14.0" @@ -891,6 +953,15 @@ fastapi = [ { name = "fastapi" }, ] +[[package]] +name = "setuptools" +version = "75.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/54/292f26c208734e9a7f067aea4a7e282c080750c4546559b58e2e45413ca0/setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6", size = 1337429 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/21/47d163f615df1d30c094f6c8bbb353619274edccf0327b185cc2493c2c33/setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d", size = 1224032 }, +] + [[package]] name = "shellingham" version = "1.5.4"