diff --git a/osrs_bot_detector_db/repositories/crud.py b/osrs_bot_detector_db/repositories/crud.py index 1bca4c8..d34fa42 100644 --- a/osrs_bot_detector_db/repositories/crud.py +++ b/osrs_bot_detector_db/repositories/crud.py @@ -30,11 +30,11 @@ async def create(self, **kwargs): :return: Created model instance """ async with self.db_session.begin(): - sql = insert(self.model).values(**kwargs).returning(self.model) - result = await self.db_session.execute(sql) - return result.fetchone() + sql = insert(self.model).values(**kwargs).prefix_with("ignore") + _ = await self.db_session.execute(sql) + await self.db_session.commit() - async def read(self, **kwargs): + async def request(self, limit: int, **kwargs): """ Asynchronously read a record based on provided field(s). :param kwargs: Field name(s) and value(s) to filter by @@ -48,11 +48,9 @@ async def read(self, **kwargs): ) filters.append(getattr(self.model, field) == value) - async with self.db_session() as session: - session: AsyncSession # must have type hint - sql = select(self.model).where(*filters) - result = await session.execute(sql) - return result.fetchone() + sql = select(self.model).where(*filters).limit(limit) + result = await self.db_session.execute(sql) + return result.mappings().all() # i don't understand why this is not good async def update(self, id_value, **kwargs): """ @@ -62,17 +60,14 @@ async def update(self, id_value, **kwargs): :return: Updated model instance """ primary_key_column = self._get_primary_key_column() - async with self.db_session() as session: - session: AsyncSession # must have type hint - async with session.begin(): - sql = ( - update(self.model) - .where(primary_key_column == id_value) - .values(**kwargs) - .returning(self.model) - ) - result = await session.execute(sql) - return result.fetchone() + async with self.db_session.begin(): + sql = ( + update(self.model) + .where(primary_key_column == id_value) + .values(**kwargs) + ) + _ = await self.db_session.execute(sql) + await self.db_session.commit() async def delete(self, id_value): """ @@ -81,13 +76,7 @@ async def delete(self, id_value): :return: Boolean indicating success """ primary_key_column = self._get_primary_key_column() - async with self.db_session() as session: - session: AsyncSession # must have type hint - async with session.begin(): - sql = ( - delete(self.model) - .where(primary_key_column == id_value) - .returning(self.model) - ) - result = await session.execute(sql) - return result.rowcount > 0 + async with self.db_session.begin(): + sql = delete(self.model).where(primary_key_column == id_value) + _ = await self.db_session.execute(sql) + await self.db_session.commit() diff --git a/osrs_bot_detector_db/repositories/label_repository.py b/osrs_bot_detector_db/repositories/label_repository.py index d0d1886..91aeee0 100644 --- a/osrs_bot_detector_db/repositories/label_repository.py +++ b/osrs_bot_detector_db/repositories/label_repository.py @@ -33,15 +33,15 @@ async def create(self, model: LabelCreate) -> LabelTable: kwargs = model.model_dump(exclude_none=True, exclude_unset=True) return await super().create(**kwargs) - async def read(self, model: LabelResponse, limit: int = 100) -> LabelTable: + async def request(self, limit: int = 100, **kwargs) -> LabelTable: """ Asynchronously read a record based on provided field(s). :param kwargs: Field name(s) and value(s) to filter by :return: LabelTable instance or None """ - kwargs = self._convert_to_kwargs(model) - kwargs["limit"] = limit if limit < 100 else limit - return await super().read(**kwargs) + labels = await super().request(limit=limit, **kwargs) + print(labels) + return [LabelResponse(**l) for l in labels] async def update(self, id_value: int, model: LabelUpdate) -> LabelTable: """ diff --git a/tests/conftest.py b/tests/conftest.py index 8e92f6b..244b511 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,8 +28,10 @@ async def setup_database(): await engine.dispose() -@pytest.fixture(scope="function") -@asynccontextmanager +import pytest_asyncio + + +@pytest_asyncio.fixture(scope="session") async def session(setup_database): """ Provide a new session for each test function. diff --git a/tests/test_label_repo.py b/tests/test_label_repo.py index 64c707c..b81efce 100644 --- a/tests/test_label_repo.py +++ b/tests/test_label_repo.py @@ -1,14 +1,54 @@ import pytest +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from osrs_bot_detector_db.repositories.label_repository import LabelRepository -from osrs_bot_detector_db.schemas.label import LabelCreate +from osrs_bot_detector_db.schemas.label import LabelCreate, LabelResponse, LabelUpdate @pytest.mark.asyncio async def test_create_label(session: AsyncSession): - async with session as db_session: - label_repo = LabelRepository(db_session=db_session) - label_create = LabelCreate(label="tester") + label_repo = LabelRepository(db_session=session) + await label_repo.create(model=LabelCreate(label="Tester")) - await label_repo.create(label_create) + +@pytest.mark.asyncio +async def test_read_label(session: AsyncSession): + label_repo = LabelRepository(db_session=session) + label = await label_repo.request(label="Tester") + print(f"{label=}") + + +# @pytest.mark.asyncio +# async def test_update_label(session: AsyncSession): +# label_crud = LabelRepository(db_session=session) +# create_model = LabelCreate(label="OldLabel") + +# created_label = await label_crud.create(model=create_model) + +# update_model = LabelUpdate(label="NewLabel") +# updated_label = await label_crud.update( +# id_value=created_label.id, model=update_model +# ) + +# assert updated_label is not None +# assert updated_label.label == "NewLabel" + + +# @pytest.mark.asyncio +# async def test_delete_label(session: AsyncSession): +# label_crud = LabelRepository(db_session=session) +# create_model = LabelCreate(label="DeleteMe") + +# created_label = await label_crud.create(model=create_model) + +# deletion_success = await label_crud.delete(id_value=created_label.id) + +# assert deletion_success is True + +# # Ensure the label is actually deleted +# try: +# await label_crud.read(label="DeleteMe") +# assert False, "Label was not deleted" +# except Exception: +# pass diff --git a/tests/test_player_repo.py b/tests/test_player_repo.py index e45d9d7..a0bebc9 100644 --- a/tests/test_player_repo.py +++ b/tests/test_player_repo.py @@ -10,19 +10,18 @@ # TODO: try create duplicate player @pytest.mark.asyncio async def test_create_player(session: AsyncSession): - async with session as db_session: - player_repository = PlayerRepository(db_session=db_session) - player_create = PlayerCreate( - name=f"Test_Player_{str(uuid.uuid4())[-4:]}", - possible_ban=False, - confirmed_ban=False, - confirmed_player=True, - label_id=1, - label_jagex=2, - ironman=False, - hardcore_ironman=False, - ultimate_ironman=False, - normalized_name="test player", - ) + player_repository = PlayerRepository(db_session=session) + player_create = PlayerCreate( + name=f"Test_Player_{str(uuid.uuid4())[-4:]}", + possible_ban=False, + confirmed_ban=False, + confirmed_player=True, + label_id=1, + label_jagex=2, + ironman=False, + hardcore_ironman=False, + ultimate_ironman=False, + normalized_name="test player", + ) - await player_repository.create(player_create) + await player_repository.create(player_create)