diff --git a/python/src/algeria_cities/models.py b/python/src/algeria_cities/models.py index 22469a1..0e95688 100644 --- a/python/src/algeria_cities/models.py +++ b/python/src/algeria_cities/models.py @@ -66,13 +66,3 @@ class PostcodeModel(BaseModel): wilaya_name_ascii: str # TODO: will need to look at empty field in JSON post_address_ascii: str = '' - - -class EntryFilter: - def __init__(self, entries: List[Union[CityModel, PostcodeModel]]): - self.entries = entries - - def find_by(self, key: str, value: Any) -> Union[CityModel, PostcodeModel]: - for entry in self.entries: - if getattr(entry, key) == value: - return entry diff --git a/python/src/algeria_cities/search.py b/python/src/algeria_cities/search.py new file mode 100644 index 0000000..d9898cf --- /dev/null +++ b/python/src/algeria_cities/search.py @@ -0,0 +1,43 @@ +from typing import List, Union, Any, Optional +from algeria_cities.models import CityModel, PostcodeModel + +ModelChoice = Union[CityModel, PostcodeModel] +ModelType = Optional[ModelChoice] + + +def filter_out_entry(key: str, value: Any): + def _inner(entry: Any): + test = getattr(entry, key) == value + return not test + return _inner + +class EntryFilter: + def __init__(self, entries: List[Union[CityModel, PostcodeModel]]): + self.entries = entries + self.indexed = dict() + + def index(self, key: str) -> "EntryFilter": + # key must be unique + if getattr(self.entries[0], key) is None: + raise Exception(f"key {key} does not exist") + + self.indexed = dict() + for entry in self.entries: + self.indexed[getattr(entry, key)] = entry + return self + + def get_entry(self, key_value): + # leaf: should not be daisy-chained + return self.indexed.get(key_value) + + def remove(self, key: str, value: Any, model_type: ModelChoice) -> "EntryFilter": + # assert model_type is what it is + assert type(self.entries[0]) == model_type + + self.entries = list(filter(filter_out_entry(key, value), self.entries)) + return self + + def find_by(self, key: str, value: Any) -> ModelType: + for entry in self.entries: + if getattr(entry, key) == value: + return entry diff --git a/python/tests/test_sql_db.py b/python/tests/test_sql_db.py index 2e8d274..00d5930 100644 --- a/python/tests/test_sql_db.py +++ b/python/tests/test_sql_db.py @@ -5,7 +5,8 @@ from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.session import Session -from algeria_cities.models import Base, City, Postcode, CityModel, PostcodeModel, EntryFilter +from algeria_cities.models import Base, City, Postcode, CityModel, PostcodeModel +from algeria_cities.search import EntryFilter from sqlite3 import Cursor from sqlite3 import Connection from unittest import TestCase @@ -95,6 +96,29 @@ def test_random_postcode_test(self): for column_key in column_names: assert getattr(entry, column_key) == getattr(sql_entry, column_key) + def test_search_by_key_test(self): + file = os.path.join(JSON_DIR, "algeria_postcodes.json") + assert os.path.isfile(file) + entries = get_file_content(file) + filter_obj = EntryFilter(entries) \ + .remove("post_code", "", PostcodeModel) \ + .index("post_code") + + sql_entries: Postcode = self.session.query(Postcode).all() + for index in range(200): + random_entry = randrange(1, 3900) + entry = sql_entries[random_entry] + + if entry.post_code == "": + continue + + json_entry = filter_obj.get_entry(entry.post_code) + column_names = Postcode.__table__.columns.keys() + column_names.remove("id") + + for column_key in column_names: + assert getattr(json_entry, column_key) == getattr(entry, column_key) + def test_for_empty_postcodes(self): sql_entry: Postcode = self.session.query(Postcode).filter(Postcode.post_code == "") current_empty_postcodes = 91