From ffb683f49bf273fe7961045842ac7bd0aa0633fa Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Tue, 14 May 2024 12:53:17 -0300 Subject: [PATCH] Add search classes to strict type checking (#1841) * Add type hints to search classes * Set future annotations * Update to string annotations * Don't use recursive type references, due to ugly syntax and incompatibility with union operator * update type hint a bit --- pyproject.toml | 6 +++ src/palace/manager/search/document.py | 43 +++++++++++--------- src/palace/manager/search/external_search.py | 6 +-- src/palace/manager/search/migrator.py | 14 ++++--- src/palace/manager/search/revision.py | 4 +- src/palace/manager/search/service.py | 42 +++++++++++-------- src/palace/manager/search/v5.py | 2 +- src/palace/manager/sqlalchemy/model/work.py | 4 +- tests/manager/search/test_external_search.py | 8 ++-- tests/manager/search/test_service.py | 5 +-- 10 files changed, 79 insertions(+), 55 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5d611555f4..d728cc1fc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,12 @@ module = [ "palace.manager.core.selftest", "palace.manager.feed.*", "palace.manager.integration.*", + "palace.manager.search.document", + "palace.manager.search.migrator", + "palace.manager.search.revision", + "palace.manager.search.revision_directory", + "palace.manager.search.service", + "palace.manager.search.v5", "palace.manager.service.*", "palace.manager.sqlalchemy.hassessioncache", "palace.manager.sqlalchemy.model.announcements", diff --git a/src/palace/manager/search/document.py b/src/palace/manager/search/document.py index 6cabe57d67..1e456333a8 100644 --- a/src/palace/manager/search/document.py +++ b/src/palace/manager/search/document.py @@ -1,4 +1,9 @@ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import Any + +SearchMappingSerialization = dict[str, Any] class SearchMappingFieldType(ABC): @@ -10,7 +15,7 @@ class SearchMappingFieldType(ABC): """ @abstractmethod - def serialize(self) -> dict: + def serialize(self) -> SearchMappingSerialization: pass @@ -24,7 +29,7 @@ class SearchMappingFieldTypeScalar(SearchMappingFieldType): def __init__(self, name: str): self._name = name - def serialize(self) -> dict: + def serialize(self) -> dict[str, str]: return {"type": self._name} @@ -77,7 +82,7 @@ def __init__(self, name: str): def parameters(self) -> dict[str, str]: return self._parameters - def serialize(self) -> dict: + def serialize(self) -> dict[str, str]: output = dict(self._parameters) output["type"] = self._name return output @@ -120,13 +125,13 @@ def __init__(self, type: str): def properties(self) -> dict[str, SearchMappingFieldType]: return self._properties - def add_property(self, name, type: SearchMappingFieldType): + def add_property(self, name: str, type: SearchMappingFieldType) -> None: self.properties[name] = type - def serialize(self) -> dict: - output_properties: dict = {} - for name, prop in self._properties.items(): - output_properties[name] = prop.serialize() + def serialize(self) -> SearchMappingSerialization: + output_properties = { + name: prop.serialize() for name, prop in self._properties.items() + } return {"type": self._type, "properties": output_properties} @@ -151,7 +156,7 @@ class SearchMappingFieldTypeCustomBasicText(SearchMappingFieldTypeCustom): that rely on stopwords. """ - def serialize(self) -> dict: + def serialize(self) -> SearchMappingSerialization: return { "type": "text", "analyzer": "en_default_text_analyzer", @@ -181,10 +186,10 @@ class SearchMappingFieldTypeCustomFilterable(SearchMappingFieldTypeCustom): can be used in filters. """ - def __init__(self): + def __init__(self) -> None: self._basic = SearchMappingFieldTypeCustomBasicText() - def serialize(self) -> dict: + def serialize(self) -> SearchMappingSerialization: output = self._basic.serialize() output["fields"]["keyword"] = { "type": "keyword", @@ -203,10 +208,10 @@ def serialize(self) -> dict: class SearchMappingFieldTypeCustomKeyword(SearchMappingFieldTypeCustom): """A custom extension to the keyword type that ensures case-insensitivity.""" - def __init__(self): + def __init__(self) -> None: self._base = keyword() - def serialize(self) -> dict: + def serialize(self) -> SearchMappingSerialization: output = self._base.serialize() output["normalizer"] = "filterable_string" return output @@ -224,13 +229,13 @@ class SearchMappingDocument: See: https://opensearch.org/docs/latest/field-types/index/ """ - def __init__(self): - self._settings: dict[str, dict] = {} + def __init__(self) -> None: + self._settings: dict[str, SearchMappingSerialization] = {} self._fields: dict[str, SearchMappingFieldType] = {} self._scripts: dict[str, str] = {} @property - def settings(self) -> dict[str, dict]: + def settings(self) -> dict[str, SearchMappingSerialization]: return self._settings @property @@ -242,13 +247,13 @@ def properties(self) -> dict[str, SearchMappingFieldType]: return self._fields @properties.setter - def properties(self, fields: dict[str, SearchMappingFieldType]): + def properties(self, fields: dict[str, SearchMappingFieldType]) -> None: self._fields = dict(fields) - def serialize(self) -> dict: + def serialize(self) -> dict[str, SearchMappingSerialization]: output_properties = self.serialize_properties() output_mappings = {"properties": output_properties} return {"settings": self.settings, "mappings": output_mappings} - def serialize_properties(self): + def serialize_properties(self) -> SearchMappingSerialization: return {name: prop.serialize() for name, prop in self._fields.items()} diff --git a/src/palace/manager/search/external_search.py b/src/palace/manager/search/external_search.py index bfe03d133b..81e6c78895 100644 --- a/src/palace/manager/search/external_search.py +++ b/src/palace/manager/search/external_search.py @@ -5,7 +5,7 @@ import re import time from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from attr import define from flask_babel import lazy_gettext as _ @@ -42,7 +42,7 @@ SearchMigrator, ) from palace.manager.search.revision_directory import SearchRevisionDirectory -from palace.manager.search.service import SearchService +from palace.manager.search.service import SearchDocument, SearchService from palace.manager.sqlalchemy.model.contributor import Contributor from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.model.identifier import Identifier @@ -255,7 +255,7 @@ def count_works(self, filter): def create_search_documents_from_works( self, works: Iterable[Work] - ) -> Iterable[dict]: + ) -> Sequence[SearchDocument]: """Create search documents for all the given works.""" if not works: # There's nothing to do. Don't bother making any requests diff --git a/src/palace/manager/search/migrator.py b/src/palace/manager/search/migrator.py index fa76d76c44..a683db86c7 100644 --- a/src/palace/manager/search/migrator.py +++ b/src/palace/manager/search/migrator.py @@ -1,10 +1,14 @@ from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Sequence from palace.manager.core.exceptions import BasePalaceException from palace.manager.search.revision import SearchSchemaRevision from palace.manager.search.revision_directory import SearchRevisionDirectory -from palace.manager.search.service import SearchService, SearchServiceFailedDocument +from palace.manager.search.service import ( + SearchDocument, + SearchService, + SearchServiceFailedDocument, +) from palace.manager.util.log import LoggerMixin @@ -21,7 +25,7 @@ class SearchDocumentReceiverType(ABC): @abstractmethod def add_documents( - self, documents: Iterable[dict] + self, documents: Sequence[SearchDocument] ) -> list[SearchServiceFailedDocument]: """Submit documents to be indexed.""" @@ -43,7 +47,7 @@ def pointer(self) -> str: return self._pointer def add_documents( - self, documents: Iterable[dict] + self, documents: Sequence[SearchDocument] ) -> list[SearchServiceFailedDocument]: """Submit documents to be indexed.""" return self._service.index_submit_documents( @@ -75,7 +79,7 @@ def __init__( ) def add_documents( - self, documents: Iterable[dict] + self, documents: Sequence[SearchDocument] ) -> list[SearchServiceFailedDocument]: """Submit documents to be indexed.""" return self._receiver.add_documents(documents) diff --git a/src/palace/manager/search/revision.py b/src/palace/manager/search/revision.py index 3b9cb3531b..ec2cbc9a49 100644 --- a/src/palace/manager/search/revision.py +++ b/src/palace/manager/search/revision.py @@ -14,7 +14,7 @@ class SearchSchemaRevision(ABC): # The SEARCH_VERSION variable MUST be populated in the implemented child classes SEARCH_VERSION: int - def __init__(self): + def __init__(self) -> None: if self.SEARCH_VERSION is None: raise ValueError("The SEARCH_VERSION must be defined with an integer value") self._version = self.SEARCH_VERSION @@ -37,5 +37,5 @@ def name_for_indexed_pointer(self, base_name: str) -> str: such as 'circulation-works-v5-indexed'.""" return f"{base_name}-v{self.version}-indexed" - def script_name(self, script_name): + def script_name(self, script_name: str) -> str: return f"simplified.{script_name}.v{self.version}" diff --git a/src/palace/manager/search/service.py b/src/palace/manager/search/service.py index 04c481e20b..6f9aec3b27 100644 --- a/src/palace/manager/search/service.py +++ b/src/palace/manager/search/service.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import re from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Sequence from dataclasses import dataclass +from typing import Any import opensearchpy.helpers from opensearch_dsl import MultiSearch, Search @@ -45,7 +48,7 @@ class SearchServiceFailedDocument: error_exception: str @classmethod - def from_bulk_error(cls, error: dict): + def from_bulk_error(cls, error: dict[str, Any]) -> SearchServiceFailedDocument: """Transform an error dictionary returned from opensearchpy's bulk API to a typed error""" if error.get("index"): error_indexed = error["index"] @@ -68,6 +71,9 @@ def from_bulk_error(cls, error: dict): ) +SearchDocument = dict[str, Any] + + class SearchService(ABC): """The interface we need from services like Opensearch. Essentially, it provides the operations we want with sensible types, rather than the untyped pile of JSON the actual search client provides. @@ -126,7 +132,7 @@ def index_set_mapping(self, revision: SearchSchemaRevision) -> None: def index_submit_documents( self, pointer: str, - documents: Iterable[dict], + documents: Sequence[SearchDocument], ) -> list[SearchServiceFailedDocument]: """Submit search documents to the given index.""" @@ -135,11 +141,11 @@ def write_pointer_set(self, revision: SearchSchemaRevision) -> None: """Atomically set the write pointer to the index for the given revision and base name.""" @abstractmethod - def refresh(self): + def refresh(self) -> None: """Synchronously refresh the service and wait for changes to be completed.""" @abstractmethod - def index_clear_documents(self, pointer: str): + def index_clear_documents(self, pointer: str) -> None: """Clear all search documents in the given index.""" @abstractmethod @@ -151,11 +157,11 @@ def search_multi_client(self, write: bool = False) -> MultiSearch: """Return the underlying search client.""" @abstractmethod - def index_remove_document(self, pointer: str, id: int): + def index_remove_document(self, pointer: str, id: int) -> None: """Remove a specific document from the given index.""" @abstractmethod - def is_pointer_empty(self, pointer: str): + def is_pointer_empty(self, pointer: str) -> bool: """Check to see if a pointer points to an empty index""" @@ -180,9 +186,7 @@ def base_revision_name(self) -> str: def write_pointer(self) -> SearchWritePointer | None: try: - result: dict = self._client.indices.get_alias( - name=self.write_pointer_name() - ) + result = self._client.indices.get_alias(name=self.write_pointer_name()) for name in result.keys(): match = re.search(f"{self.base_revision_name}-v([0-9]+)", string=name) if match: @@ -276,7 +280,7 @@ def _ensure_scripts(self, revision: SearchSchemaRevision) -> None: self._client.put_script(name, script) # type: ignore [misc] ## Seems the types aren't up to date def index_submit_documents( - self, pointer: str, documents: Iterable[dict] + self, pointer: str, documents: Sequence[SearchDocument] ) -> list[SearchServiceFailedDocument]: self.log.info(f"submitting documents to index {pointer}") @@ -313,12 +317,12 @@ def index_submit_documents( return error_results - def index_clear_documents(self, pointer: str): + def index_clear_documents(self, pointer: str) -> None: self._client.delete_by_query( index=pointer, body={"query": {"match_all": {}}}, wait_for_completion=True ) - def refresh(self): + def refresh(self) -> None: self.log.debug(f"waiting for indexes to become ready") self._client.indices.refresh() @@ -336,7 +340,9 @@ def write_pointer_set(self, revision: SearchSchemaRevision) -> None: def read_pointer(self) -> str | None: try: - result: dict = self._client.indices.get_alias(name=self.read_pointer_name()) + result: dict[str, Any] = self._client.indices.get_alias( + name=self.read_pointer_name() + ) for name in result.keys(): if name.startswith(f"{self.base_revision_name}-"): return name @@ -344,12 +350,12 @@ def read_pointer(self) -> str | None: except NotFoundError: return None - def search_client(self, write=False) -> Search: + def search_client(self, write: bool = False) -> Search: return self._search.index( self.read_pointer_name() if not write else self.write_pointer_name() ) - def search_multi_client(self, write=False) -> MultiSearch: + def search_multi_client(self, write: bool = False) -> MultiSearch: return self._multi_search.index( self.read_pointer_name() if not write else self.write_pointer_name() ) @@ -361,10 +367,10 @@ def write_pointer_name(self) -> str: return f"{self.base_revision_name}-search-write" @staticmethod - def _empty(base_name): + def _empty(base_name: str) -> str: return f"{base_name}-empty" - def index_remove_document(self, pointer: str, id: int): + def index_remove_document(self, pointer: str, id: int) -> None: self._client.delete(index=pointer, id=id, doc_type="_doc") def is_pointer_empty(self, pointer: str) -> bool: diff --git a/src/palace/manager/search/v5.py b/src/palace/manager/search/v5.py index 18df6d432e..c0237aa391 100644 --- a/src/palace/manager/search/v5.py +++ b/src/palace/manager/search/v5.py @@ -97,7 +97,7 @@ class SearchV5(SearchSchemaRevision): CHAR_FILTERS[name] = normalizer AUTHOR_CHAR_FILTER_NAMES.append(name) - def __init__(self): + def __init__(self) -> None: super().__init__() self._normalizers = {} diff --git a/src/palace/manager/sqlalchemy/model/work.py b/src/palace/manager/sqlalchemy/model/work.py index 0a6f45be66..5f19627f55 100644 --- a/src/palace/manager/sqlalchemy/model/work.py +++ b/src/palace/manager/sqlalchemy/model/work.py @@ -4,6 +4,7 @@ import sys from collections import Counter +from collections.abc import Sequence from datetime import date, datetime from decimal import Decimal from typing import TYPE_CHECKING, Any, cast @@ -32,6 +33,7 @@ from sqlalchemy.sql.functions import func from palace.manager.core.classifier import Classifier, WorkClassifier +from palace.manager.search.service import SearchDocument from palace.manager.sqlalchemy.constants import DataSourceConstants from palace.manager.sqlalchemy.model.base import Base from palace.manager.sqlalchemy.model.classification import ( @@ -1413,7 +1415,7 @@ def assign_appeals(self, character, language, setting, story, cutoff=0.20): OPENSEARCH_TIME_FORMAT = 'YYYY-MM-DD"T"HH24:MI:SS"."MS' @classmethod - def to_search_documents(cls, works: list[Self]) -> list[dict]: + def to_search_documents(cls, works: list[Self]) -> Sequence[SearchDocument]: """In app to search documents needed to ease off the burden of complex queries from the DB cluster No recursive identifier policy is taken here as using the diff --git a/tests/manager/search/test_external_search.py b/tests/manager/search/test_external_search.py index 8f6cd1a45d..48f38f7398 100644 --- a/tests/manager/search/test_external_search.py +++ b/tests/manager/search/test_external_search.py @@ -4905,15 +4905,17 @@ def test_to_search_documents_with_missing_data( work: Work = db.work(with_license_pool=True) work.presentation_edition_id = None [result] = Work.to_search_documents([work]) - assert result["identifiers"] == None + assert result["identifiers"] is None # Missing just some attributes work = db.work(with_license_pool=True) work.presentation_edition.title = None work.target_age = None [result] = Work.to_search_documents([work]) - assert result["title"] == None - assert result["target_age"]["lower"] == None + assert result["title"] is None + target_age = result["target_age"] + assert isinstance(target_age, dict) + assert target_age["lower"] is None def test_success( self, diff --git a/tests/manager/search/test_service.py b/tests/manager/search/test_service.py index 46117fa145..ef88411eeb 100644 --- a/tests/manager/search/test_service.py +++ b/tests/manager/search/test_service.py @@ -1,7 +1,6 @@ -from collections.abc import Iterable - from palace.manager.search.document import LONG, SearchMappingDocument from palace.manager.search.revision import SearchSchemaRevision +from palace.manager.search.service import SearchDocument from tests.fixtures.search import ExternalSearchFixture @@ -104,7 +103,7 @@ def test_populate_index_idempotent( # The format expected by the opensearch bulk helper is completely undocumented. # It does, however, appear to use mostly the same format as the Elasticsearch equivalent. # See: https://elasticsearch-py.readthedocs.io/en/v7.13.1/helpers.html#bulk-helpers - documents: Iterable[dict] = [ + documents: list[SearchDocument] = [ { "_index": revision.name_for_index("base"), "_type": "_doc",