-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ensure base for testing concurrent payments
- Loading branch information
Showing
3 changed files
with
122 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
import asyncio | ||
from typing import cast | ||
from unittest.mock import AsyncMock | ||
|
||
import pytest | ||
from httpx import ASGITransport, AsyncClient | ||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine | ||
from sqlalchemy.ext.asyncio import ( | ||
async_scoped_session, | ||
async_sessionmaker, | ||
create_async_engine, | ||
) | ||
from sqlmodel import insert | ||
from sqlmodel.ext.asyncio.session import AsyncSession | ||
from sqlmodel.pool import StaticPool | ||
|
@@ -29,20 +35,28 @@ def redis_client(mocker): | |
) | ||
|
||
|
||
@pytest.fixture(name="session") | ||
async def session_fixture(): | ||
@pytest.fixture | ||
async def _scoped_session(): | ||
async_engine = create_async_engine( | ||
"sqlite+aiosqlite://", | ||
connect_args={"check_same_thread": False}, | ||
poolclass=StaticPool, | ||
) | ||
async with async_engine.begin() as conn: | ||
await conn.run_sync(SQLModel.metadata.create_all) | ||
|
||
async_session = async_sessionmaker( | ||
session_factory = async_sessionmaker( | ||
bind=async_engine, class_=AsyncSession, expire_on_commit=False | ||
) | ||
async with async_session() as session: | ||
scoped_session = async_scoped_session( | ||
session_factory, scopefunc=asyncio.current_task | ||
) | ||
yield scoped_session | ||
await scoped_session.remove() | ||
|
||
|
||
@pytest.fixture | ||
async def session(_scoped_session: async_scoped_session): | ||
async with _scoped_session() as session: | ||
yield session | ||
|
||
|
||
|
@@ -89,53 +103,68 @@ async def initial_data(session: AsyncSession) -> dict: | |
}, | ||
] | ||
|
||
await session.exec(insert(IceCream).values(icecream_dump)) | ||
await session.exec(insert(IceCream).values(icecream_dump)) # type: ignore | ||
clean_users_dump = [ | ||
{key: value for key, value in user_data.items() if key != "password"} | ||
for user_data in users_dump | ||
] | ||
await session.exec(insert(User).values(clean_users_dump)) | ||
await session.exec(insert(User).values(clean_users_dump)) # type: ignore | ||
await session.commit() | ||
|
||
return {"icecream": icecream_dump, "users": users_dump} | ||
|
||
|
||
@pytest.fixture(name="client") | ||
async def client_fixture(session: AsyncSession, mocker): | ||
async def get_session_override(): | ||
yield session | ||
@pytest.fixture | ||
async def users(session: AsyncSession, initial_data: dict) -> list[User]: | ||
return list((await User.fetch(session)).all()) | ||
|
||
app.state.redis_pool = mocker.AsyncMock() | ||
app.dependency_overrides[get_async_session] = get_session_override | ||
async with AsyncClient( | ||
transport=ASGITransport(app=app), base_url="http://localhost" | ||
) as client: | ||
yield client | ||
app.dependency_overrides.clear() | ||
|
||
@pytest.fixture | ||
async def user(users: list[User]) -> User: | ||
return [usr for usr in users if usr.email == "[email protected]"][0] | ||
|
||
|
||
async def _get_users(session: AsyncSession, email: str | None = None) -> list[User]: | ||
filters = [] | ||
if email: | ||
filters.append(User.email == email) | ||
result = await User.fetch(session, filters=filters) | ||
users = [result.one()] if email else result.all() | ||
return users | ||
@pytest.fixture | ||
async def secondary_user(users: list[User], user: User) -> User: | ||
return [usr for usr in users if usr.email != user.email][0] | ||
|
||
|
||
@pytest.fixture | ||
async def users(session: AsyncSession, initial_data: dict) -> list[User]: | ||
return await _get_users(session) | ||
async def _client_factory(_scoped_session: async_scoped_session, mocker): | ||
async def get_session_override(): | ||
async with _scoped_session() as session: | ||
yield session | ||
|
||
async def _create_client(): | ||
app.state.redis_pool = mocker.AsyncMock() | ||
app.dependency_overrides[get_async_session] = get_session_override | ||
async with AsyncClient( | ||
transport=ASGITransport(app=app), base_url="http://localhost" | ||
) as client: | ||
yield client | ||
app.dependency_overrides.clear() | ||
|
||
return _create_client | ||
|
||
|
||
@pytest.fixture | ||
async def user(users: list[User]) -> User: | ||
return [usr for usr in users if usr.email == "[email protected]"][0] | ||
async def client(_client_factory): | ||
async for client in _client_factory(): | ||
yield client | ||
|
||
|
||
@pytest.fixture | ||
async def secondary_client(_client_factory): | ||
async for client in _client_factory(): | ||
yield client | ||
|
||
|
||
async def _get_auth_token(initial_data: dict, user: User, client: AsyncClient) -> str: | ||
users_dump = initial_data["users"] | ||
user_dump = [item for item in users_dump if item["email"] == user.email][0] | ||
|
||
async def _get_auth_token(client: AsyncClient, *, email: str, password: str) -> str: | ||
# Authenticate and get the token. | ||
form_data = {"username": email, "password": password} | ||
form_data = {"username": user.email, "password": user_dump["password"]} | ||
response = await client.post("/v1/auth/access-token", data=form_data) | ||
assert response.status_code == 200 | ||
|
||
|
@@ -146,60 +175,29 @@ async def _get_auth_token(client: AsyncClient, *, email: str, password: str) -> | |
|
||
|
||
@pytest.fixture | ||
async def auth_tokens( | ||
initial_data: dict, users: list[User], client: AsyncClient | ||
) -> list[dict[str, str]]: | ||
users_dump = initial_data["users"] | ||
tokens = [] | ||
for user in users: | ||
user_dump = [item for item in users_dump if item["email"] == user.email][0] | ||
token = await _get_auth_token( | ||
client, email=user.email, password=user_dump["password"] | ||
) | ||
tokens.append({"email": user.email, "token": token}) | ||
return tokens | ||
|
||
|
||
@pytest.fixture | ||
async def auth_token(initial_data: dict, user: User, client: AsyncClient) -> str: | ||
users_dump = initial_data["users"] | ||
user_dump = [item for item in users_dump if item["email"] == user.email][0] | ||
return await _get_auth_token( | ||
client, email=user.email, password=user_dump["password"] | ||
) | ||
|
||
|
||
@pytest.fixture | ||
async def auth_client(client: AsyncClient, auth_token: str): | ||
client.headers.update({"Authorization": f"Bearer {auth_token}"}) | ||
async def auth_client(initial_data: dict, user: User, client: AsyncClient): | ||
token = await _get_auth_token(initial_data, user, client) | ||
client.headers.update({"Authorization": f"Bearer {token}"}) | ||
return client | ||
|
||
|
||
@pytest.fixture | ||
async def secondary_auth_client( | ||
user: User, client: AsyncClient, auth_tokens: list[dict[str, str]] | ||
initial_data: dict, secondary_user: User, secondary_client: AsyncClient | ||
): | ||
token = [tkn for tkn in auth_tokens if tkn["email"] != user.email][0] | ||
client.headers.update({"Authorization": f"Bearer {token}"}) | ||
return client | ||
token = await _get_auth_token(initial_data, secondary_user, secondary_client) | ||
secondary_client.headers.update({"Authorization": f"Bearer {token}"}) | ||
return secondary_client | ||
|
||
|
||
@pytest.fixture | ||
async def cart_items( | ||
session: AsyncSession, initial_data: dict, user: User | ||
) -> list[CartItem]: | ||
async def _create_cart_with_items(session: AsyncSession, user: User) -> list[CartItem]: | ||
cart = Cart(user_id=user.id) | ||
session.add(cart) | ||
await session.commit() | ||
await session.refresh(cart) | ||
|
||
items = [] | ||
for ice_data in initial_data["icecream"]: | ||
icecream = ( | ||
await IceCream.fetch( | ||
session, filters=[IceCream.flavor == ice_data["flavor"]] | ||
) | ||
).one() | ||
for icecream in (await IceCream.fetch(session)).all(): | ||
cart_item = CartItem( | ||
cart_id=cart.id, | ||
icecream_id=icecream.id, | ||
|
@@ -210,13 +208,34 @@ async def cart_items( | |
session.add_all(items) | ||
await session.commit() | ||
|
||
items = list( | ||
( | ||
await CartItem.fetch( | ||
session, | ||
filters=[CartItem.id == cart.id], | ||
joinedloads=[CartItem.icecream], | ||
) | ||
).all() | ||
) | ||
return items | ||
|
||
|
||
@pytest.fixture | ||
async def cart_items(session: AsyncSession, user: User) -> list[CartItem]: | ||
return await _create_cart_with_items(session, user) | ||
|
||
|
||
@pytest.fixture | ||
async def secondary_cart_items( | ||
session: AsyncSession, secondary_user: User | ||
) -> list[CartItem]: | ||
return await _create_cart_with_items(session, secondary_user) | ||
|
||
|
||
@pytest.fixture | ||
async def order(session: AsyncSession, cart_items: list[CartItem], user: User) -> Order: | ||
cart_service = CartService(session) | ||
cart = await cart_service.get_cart(user.id) | ||
cart = await cart_service.ensure_cart(cast(int, user.id)) | ||
order_service = OrderService(session, stats_service=stats_service) | ||
order = await order_service.make_order_from_cart(cart) | ||
await session.commit() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters