Skip to content

Commit

Permalink
still buggy read
Browse files Browse the repository at this point in the history
  • Loading branch information
extreme4all committed Aug 24, 2024
1 parent 0b94ab0 commit d48eb9a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 56 deletions.
49 changes: 19 additions & 30 deletions osrs_bot_detector_db/repositories/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ async def create(self, **kwargs):
:return: Created model instance
"""
async with self.db_session.begin():
sql = insert(self.model).values(**kwargs).returning(self.model)
result = await self.db_session.execute(sql)
return result.fetchone()
sql = insert(self.model).values(**kwargs).prefix_with("ignore")
_ = await self.db_session.execute(sql)
await self.db_session.commit()

async def read(self, **kwargs):
async def request(self, limit: int, **kwargs):
"""
Asynchronously read a record based on provided field(s).
:param kwargs: Field name(s) and value(s) to filter by
Expand All @@ -48,11 +48,9 @@ async def read(self, **kwargs):
)
filters.append(getattr(self.model, field) == value)

async with self.db_session() as session:
session: AsyncSession # must have type hint
sql = select(self.model).where(*filters)
result = await session.execute(sql)
return result.fetchone()
sql = select(self.model).where(*filters).limit(limit)
result = await self.db_session.execute(sql)
return result.mappings().all() # i don't understand why this is not good

async def update(self, id_value, **kwargs):
"""
Expand All @@ -62,17 +60,14 @@ async def update(self, id_value, **kwargs):
:return: Updated model instance
"""
primary_key_column = self._get_primary_key_column()
async with self.db_session() as session:
session: AsyncSession # must have type hint
async with session.begin():
sql = (
update(self.model)
.where(primary_key_column == id_value)
.values(**kwargs)
.returning(self.model)
)
result = await session.execute(sql)
return result.fetchone()
async with self.db_session.begin():
sql = (
update(self.model)
.where(primary_key_column == id_value)
.values(**kwargs)
)
_ = await self.db_session.execute(sql)
await self.db_session.commit()

async def delete(self, id_value):
"""
Expand All @@ -81,13 +76,7 @@ async def delete(self, id_value):
:return: Boolean indicating success
"""
primary_key_column = self._get_primary_key_column()
async with self.db_session() as session:
session: AsyncSession # must have type hint
async with session.begin():
sql = (
delete(self.model)
.where(primary_key_column == id_value)
.returning(self.model)
)
result = await session.execute(sql)
return result.rowcount > 0
async with self.db_session.begin():
sql = delete(self.model).where(primary_key_column == id_value)
_ = await self.db_session.execute(sql)
await self.db_session.commit()
8 changes: 4 additions & 4 deletions osrs_bot_detector_db/repositories/label_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ async def create(self, model: LabelCreate) -> LabelTable:
kwargs = model.model_dump(exclude_none=True, exclude_unset=True)
return await super().create(**kwargs)

async def read(self, model: LabelResponse, limit: int = 100) -> LabelTable:
async def request(self, limit: int = 100, **kwargs) -> LabelTable:
"""
Asynchronously read a record based on provided field(s).
:param kwargs: Field name(s) and value(s) to filter by
:return: LabelTable instance or None
"""
kwargs = self._convert_to_kwargs(model)
kwargs["limit"] = limit if limit < 100 else limit
return await super().read(**kwargs)
labels = await super().request(limit=limit, **kwargs)
print(labels)
return [LabelResponse(**l) for l in labels]

async def update(self, id_value: int, model: LabelUpdate) -> LabelTable:
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ async def setup_database():
await engine.dispose()


@pytest.fixture(scope="function")
@asynccontextmanager
import pytest_asyncio


@pytest_asyncio.fixture(scope="session")
async def session(setup_database):
"""
Provide a new session for each test function.
Expand Down
50 changes: 45 additions & 5 deletions tests/test_label_repo.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,54 @@
import pytest
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from osrs_bot_detector_db.repositories.label_repository import LabelRepository
from osrs_bot_detector_db.schemas.label import LabelCreate
from osrs_bot_detector_db.schemas.label import LabelCreate, LabelResponse, LabelUpdate


@pytest.mark.asyncio
async def test_create_label(session: AsyncSession):
async with session as db_session:
label_repo = LabelRepository(db_session=db_session)
label_create = LabelCreate(label="tester")
label_repo = LabelRepository(db_session=session)
await label_repo.create(model=LabelCreate(label="Tester"))

await label_repo.create(label_create)

@pytest.mark.asyncio
async def test_read_label(session: AsyncSession):
label_repo = LabelRepository(db_session=session)
label = await label_repo.request(label="Tester")
print(f"{label=}")


# @pytest.mark.asyncio
# async def test_update_label(session: AsyncSession):
# label_crud = LabelRepository(db_session=session)
# create_model = LabelCreate(label="OldLabel")

# created_label = await label_crud.create(model=create_model)

# update_model = LabelUpdate(label="NewLabel")
# updated_label = await label_crud.update(
# id_value=created_label.id, model=update_model
# )

# assert updated_label is not None
# assert updated_label.label == "NewLabel"


# @pytest.mark.asyncio
# async def test_delete_label(session: AsyncSession):
# label_crud = LabelRepository(db_session=session)
# create_model = LabelCreate(label="DeleteMe")

# created_label = await label_crud.create(model=create_model)

# deletion_success = await label_crud.delete(id_value=created_label.id)

# assert deletion_success is True

# # Ensure the label is actually deleted
# try:
# await label_crud.read(label="DeleteMe")
# assert False, "Label was not deleted"
# except Exception:
# pass
29 changes: 14 additions & 15 deletions tests/test_player_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
# TODO: try create duplicate player
@pytest.mark.asyncio
async def test_create_player(session: AsyncSession):
async with session as db_session:
player_repository = PlayerRepository(db_session=db_session)
player_create = PlayerCreate(
name=f"Test_Player_{str(uuid.uuid4())[-4:]}",
possible_ban=False,
confirmed_ban=False,
confirmed_player=True,
label_id=1,
label_jagex=2,
ironman=False,
hardcore_ironman=False,
ultimate_ironman=False,
normalized_name="test player",
)
player_repository = PlayerRepository(db_session=session)
player_create = PlayerCreate(
name=f"Test_Player_{str(uuid.uuid4())[-4:]}",
possible_ban=False,
confirmed_ban=False,
confirmed_player=True,
label_id=1,
label_jagex=2,
ironman=False,
hardcore_ironman=False,
ultimate_ironman=False,
normalized_name="test player",
)

await player_repository.create(player_create)
await player_repository.create(player_create)

0 comments on commit d48eb9a

Please sign in to comment.