diff --git a/registry/drivers.py b/registry/drivers.py index aa6d81d..6356f72 100644 --- a/registry/drivers.py +++ b/registry/drivers.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Optional, Protocol from registry.entity import Entity from registry.schema import StorageDriver @@ -8,14 +8,23 @@ class Driver(Protocol): def __init__(self, dsn: str) -> None: ... - async def create(self, entity: type[Entity], data: dict) -> dict: - raise NotImplementedError() - async def find( - self, entity: type[Entity], queries: list[dict] + self, + entity: type[Entity], + queries: list[dict], + limit: Optional[int] = None, ) -> list[dict]: raise NotImplementedError() + async def find_one( + self, + entity: type[Entity], + queries: list[dict], + ) -> Optional[dict]: + rows = await self.find(entity, queries, limit=1) + if len(rows): + return rows[0] + async def find_or_create( self, entity: type[Entity], query: dict, data: dict ) -> dict: @@ -23,11 +32,25 @@ async def find_or_create( if len(result): return result[0] - return await self.create(entity, data) + return await self.insert(entity, data) + + async def find_or_fail( + self, + entity: type[Entity], + queries: list[dict], + ) -> dict: + instance = await self.find_one(entity, queries) + if not instance: + raise LookupError(f'{entity.nick} not found') + + return instance async def init_schema(self, entity: type[Entity]) -> None: raise NotImplementedError() + async def insert(self, entity: type[Entity], data: dict) -> dict: + raise NotImplementedError() + driver_instances: dict[str, dict[str, Driver]] = {} @@ -53,25 +76,31 @@ class MemoryDriver(Driver): def __init__(self, dsn: str) -> None: self.data: dict[type[Entity], list[dict]] = {} - async def create(self, entity: type[Entity], data: dict) -> dict: - await self.init_schema(entity) - data['id'] = len(self.data[entity]) + 1 - self.data[entity].append(data) - return data - async def find( - self, entity: type[Entity], queries: list[dict] + self, + entity: type[Entity], + queries: list[dict], + limit: Optional[int] = None, ) -> list[dict]: await self.init_schema(entity) - return [ + rows = [ row for row in self.data[entity] if await self.is_valid(row, queries) ] + if limit: + rows = rows[0:limit] + return rows async def init_schema(self, entity: type[Entity]) -> None: if entity not in self.data: self.data[entity] = [] + async def insert(self, entity: type[Entity], data: dict) -> dict: + await self.init_schema(entity) + data['id'] = len(self.data[entity]) + 1 + self.data[entity].append(data) + return data + async def is_valid(self, row, queries: list) -> bool: for query in queries: if False not in [ diff --git a/registry/registry.py b/registry/registry.py index d435503..6d743ef 100644 --- a/registry/registry.py +++ b/registry/registry.py @@ -89,7 +89,7 @@ async def create( context = await self.context(entity, key) return context.repository.make( entity=entity, - row=await context.driver.create( + row=await context.driver.insert( entity=entity, data=dict(bucket_id=context.bucket.id, **data), ) diff --git a/tests/test_repository.py b/tests/test_repository.py index 3bd5ab3..645b635 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -33,18 +33,35 @@ class ActionRepository(Repository): async def test_hello(): registry = Registry() assert len(await registry.find(Action)) == 0 + + # create two actions action1 = await registry.find_or_create(Action, {'type': 'tester'}) action2 = await registry.find_or_create(Action, {'type': 'tester2'}) + + # validate properties assert action1.id == 1 assert action1.type == 'tester' assert action1.owner_id == 0 assert action2.id == 2 assert action2.type == 'tester2' assert action2.owner_id == 0 + + # identity map + action3 = await registry.find_or_create(Action, {'id': 2}) + assert action3 == action2 + action4 = await registry.find_or_create(Action, {'type': 'tester'}) + assert action4 == action1 assert len(await registry.find(Action)) == 2 + + # lookup checks assert (await registry.get_instance(Action, 2)).type == 'tester2' assert (await registry.get_instance(Action, 3)) is None + # storage level persistence check [storage] = registry.storages driver = await get_driver(storage.driver, storage.dsn) - assert driver.data[Action][0]['owner_id'] == 0 + assert len(await driver.find(Action, queries=[{}])) == 2 + + # default values peristence + first_action_dict = await driver.find_one(Action, queries=[{'id': 1}]) + assert first_action_dict['owner_id'] == 0