Skip to content

Commit

Permalink
Add search classes to strict type checking (#1841)
Browse files Browse the repository at this point in the history
* Add type hints to search classes
* Set future annotations
* Update to string annotations
* Don't use recursive type references, due to their ugly syntax and incompatibility with the union operator
* Update type hint a bit
  • Loading branch information
jonathangreen committed May 24, 2024
1 parent acf0b8f commit ffb683f
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 55 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ module = [
"palace.manager.core.selftest",
"palace.manager.feed.*",
"palace.manager.integration.*",
"palace.manager.search.document",
"palace.manager.search.migrator",
"palace.manager.search.revision",
"palace.manager.search.revision_directory",
"palace.manager.search.service",
"palace.manager.search.v5",
"palace.manager.service.*",
"palace.manager.sqlalchemy.hassessioncache",
"palace.manager.sqlalchemy.model.announcements",
Expand Down
43 changes: 24 additions & 19 deletions src/palace/manager/search/document.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

SearchMappingSerialization = dict[str, Any]


class SearchMappingFieldType(ABC):
Expand All @@ -10,7 +15,7 @@ class SearchMappingFieldType(ABC):
"""

@abstractmethod
def serialize(self) -> dict:
def serialize(self) -> SearchMappingSerialization:
pass


Expand All @@ -24,7 +29,7 @@ class SearchMappingFieldTypeScalar(SearchMappingFieldType):
def __init__(self, name: str):
self._name = name

def serialize(self) -> dict:
def serialize(self) -> dict[str, str]:
return {"type": self._name}


Expand Down Expand Up @@ -77,7 +82,7 @@ def __init__(self, name: str):
def parameters(self) -> dict[str, str]:
return self._parameters

def serialize(self) -> dict:
def serialize(self) -> dict[str, str]:
output = dict(self._parameters)
output["type"] = self._name
return output
Expand Down Expand Up @@ -120,13 +125,13 @@ def __init__(self, type: str):
def properties(self) -> dict[str, SearchMappingFieldType]:
return self._properties

def add_property(self, name, type: SearchMappingFieldType):
def add_property(self, name: str, type: SearchMappingFieldType) -> None:
self.properties[name] = type

def serialize(self) -> dict:
output_properties: dict = {}
for name, prop in self._properties.items():
output_properties[name] = prop.serialize()
def serialize(self) -> SearchMappingSerialization:
output_properties = {
name: prop.serialize() for name, prop in self._properties.items()
}

return {"type": self._type, "properties": output_properties}

Expand All @@ -151,7 +156,7 @@ class SearchMappingFieldTypeCustomBasicText(SearchMappingFieldTypeCustom):
that rely on stopwords.
"""

def serialize(self) -> dict:
def serialize(self) -> SearchMappingSerialization:
return {
"type": "text",
"analyzer": "en_default_text_analyzer",
Expand Down Expand Up @@ -181,10 +186,10 @@ class SearchMappingFieldTypeCustomFilterable(SearchMappingFieldTypeCustom):
can be used in filters.
"""

def __init__(self):
def __init__(self) -> None:
self._basic = SearchMappingFieldTypeCustomBasicText()

def serialize(self) -> dict:
def serialize(self) -> SearchMappingSerialization:
output = self._basic.serialize()
output["fields"]["keyword"] = {
"type": "keyword",
Expand All @@ -203,10 +208,10 @@ def serialize(self) -> dict:
class SearchMappingFieldTypeCustomKeyword(SearchMappingFieldTypeCustom):
"""A custom extension to the keyword type that ensures case-insensitivity."""

def __init__(self):
def __init__(self) -> None:
self._base = keyword()

def serialize(self) -> dict:
def serialize(self) -> SearchMappingSerialization:
output = self._base.serialize()
output["normalizer"] = "filterable_string"
return output
Expand All @@ -224,13 +229,13 @@ class SearchMappingDocument:
See: https://opensearch.org/docs/latest/field-types/index/
"""

def __init__(self):
self._settings: dict[str, dict] = {}
def __init__(self) -> None:
self._settings: dict[str, SearchMappingSerialization] = {}
self._fields: dict[str, SearchMappingFieldType] = {}
self._scripts: dict[str, str] = {}

@property
def settings(self) -> dict[str, dict]:
def settings(self) -> dict[str, SearchMappingSerialization]:
return self._settings

@property
Expand All @@ -242,13 +247,13 @@ def properties(self) -> dict[str, SearchMappingFieldType]:
return self._fields

@properties.setter
def properties(self, fields: dict[str, SearchMappingFieldType]):
def properties(self, fields: dict[str, SearchMappingFieldType]) -> None:
self._fields = dict(fields)

def serialize(self) -> dict:
def serialize(self) -> dict[str, SearchMappingSerialization]:
output_properties = self.serialize_properties()
output_mappings = {"properties": output_properties}
return {"settings": self.settings, "mappings": output_mappings}

def serialize_properties(self):
def serialize_properties(self) -> SearchMappingSerialization:
return {name: prop.serialize() for name, prop in self._fields.items()}
6 changes: 3 additions & 3 deletions src/palace/manager/search/external_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import time
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Iterable, Sequence

from attr import define
from flask_babel import lazy_gettext as _
Expand Down Expand Up @@ -42,7 +42,7 @@
SearchMigrator,
)
from palace.manager.search.revision_directory import SearchRevisionDirectory
from palace.manager.search.service import SearchService
from palace.manager.search.service import SearchDocument, SearchService
from palace.manager.sqlalchemy.model.contributor import Contributor
from palace.manager.sqlalchemy.model.edition import Edition
from palace.manager.sqlalchemy.model.identifier import Identifier
Expand Down Expand Up @@ -255,7 +255,7 @@ def count_works(self, filter):

def create_search_documents_from_works(
self, works: Iterable[Work]
) -> Iterable[dict]:
) -> Sequence[SearchDocument]:
"""Create search documents for all the given works."""
if not works:
# There's nothing to do. Don't bother making any requests
Expand Down
14 changes: 9 additions & 5 deletions src/palace/manager/search/migrator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Sequence

from palace.manager.core.exceptions import BasePalaceException
from palace.manager.search.revision import SearchSchemaRevision
from palace.manager.search.revision_directory import SearchRevisionDirectory
from palace.manager.search.service import SearchService, SearchServiceFailedDocument
from palace.manager.search.service import (
SearchDocument,
SearchService,
SearchServiceFailedDocument,
)
from palace.manager.util.log import LoggerMixin


Expand All @@ -21,7 +25,7 @@ class SearchDocumentReceiverType(ABC):

@abstractmethod
def add_documents(
self, documents: Iterable[dict]
self, documents: Sequence[SearchDocument]
) -> list[SearchServiceFailedDocument]:
"""Submit documents to be indexed."""

Expand All @@ -43,7 +47,7 @@ def pointer(self) -> str:
return self._pointer

def add_documents(
self, documents: Iterable[dict]
self, documents: Sequence[SearchDocument]
) -> list[SearchServiceFailedDocument]:
"""Submit documents to be indexed."""
return self._service.index_submit_documents(
Expand Down Expand Up @@ -75,7 +79,7 @@ def __init__(
)

def add_documents(
self, documents: Iterable[dict]
self, documents: Sequence[SearchDocument]
) -> list[SearchServiceFailedDocument]:
"""Submit documents to be indexed."""
return self._receiver.add_documents(documents)
Expand Down
4 changes: 2 additions & 2 deletions src/palace/manager/search/revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class SearchSchemaRevision(ABC):
# The SEARCH_VERSION variable MUST be populated in the implemented child classes
SEARCH_VERSION: int

def __init__(self):
def __init__(self) -> None:
if self.SEARCH_VERSION is None:
raise ValueError("The SEARCH_VERSION must be defined with an integer value")
self._version = self.SEARCH_VERSION
Expand All @@ -37,5 +37,5 @@ def name_for_indexed_pointer(self, base_name: str) -> str:
such as 'circulation-works-v5-indexed'."""
return f"{base_name}-v{self.version}-indexed"

def script_name(self, script_name):
def script_name(self, script_name: str) -> str:
return f"simplified.{script_name}.v{self.version}"
42 changes: 24 additions & 18 deletions src/palace/manager/search/service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import re
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

import opensearchpy.helpers
from opensearch_dsl import MultiSearch, Search
Expand Down Expand Up @@ -45,7 +48,7 @@ class SearchServiceFailedDocument:
error_exception: str

@classmethod
def from_bulk_error(cls, error: dict):
def from_bulk_error(cls, error: dict[str, Any]) -> SearchServiceFailedDocument:
"""Transform an error dictionary returned from opensearchpy's bulk API to a typed error"""
if error.get("index"):
error_indexed = error["index"]
Expand All @@ -68,6 +71,9 @@ def from_bulk_error(cls, error: dict):
)


SearchDocument = dict[str, Any]


class SearchService(ABC):
"""The interface we need from services like Opensearch. Essentially, it provides the operations we want with
sensible types, rather than the untyped pile of JSON the actual search client provides.
Expand Down Expand Up @@ -126,7 +132,7 @@ def index_set_mapping(self, revision: SearchSchemaRevision) -> None:
def index_submit_documents(
self,
pointer: str,
documents: Iterable[dict],
documents: Sequence[SearchDocument],
) -> list[SearchServiceFailedDocument]:
"""Submit search documents to the given index."""

Expand All @@ -135,11 +141,11 @@ def write_pointer_set(self, revision: SearchSchemaRevision) -> None:
"""Atomically set the write pointer to the index for the given revision and base name."""

@abstractmethod
def refresh(self):
def refresh(self) -> None:
"""Synchronously refresh the service and wait for changes to be completed."""

@abstractmethod
def index_clear_documents(self, pointer: str):
def index_clear_documents(self, pointer: str) -> None:
"""Clear all search documents in the given index."""

@abstractmethod
Expand All @@ -151,11 +157,11 @@ def search_multi_client(self, write: bool = False) -> MultiSearch:
"""Return the underlying search client."""

@abstractmethod
def index_remove_document(self, pointer: str, id: int):
def index_remove_document(self, pointer: str, id: int) -> None:
"""Remove a specific document from the given index."""

@abstractmethod
def is_pointer_empty(self, pointer: str):
def is_pointer_empty(self, pointer: str) -> bool:
"""Check to see if a pointer points to an empty index"""


Expand All @@ -180,9 +186,7 @@ def base_revision_name(self) -> str:

def write_pointer(self) -> SearchWritePointer | None:
try:
result: dict = self._client.indices.get_alias(
name=self.write_pointer_name()
)
result = self._client.indices.get_alias(name=self.write_pointer_name())
for name in result.keys():
match = re.search(f"{self.base_revision_name}-v([0-9]+)", string=name)
if match:
Expand Down Expand Up @@ -276,7 +280,7 @@ def _ensure_scripts(self, revision: SearchSchemaRevision) -> None:
self._client.put_script(name, script) # type: ignore [misc] ## Seems the types aren't up to date

def index_submit_documents(
self, pointer: str, documents: Iterable[dict]
self, pointer: str, documents: Sequence[SearchDocument]
) -> list[SearchServiceFailedDocument]:
self.log.info(f"submitting documents to index {pointer}")

Expand Down Expand Up @@ -313,12 +317,12 @@ def index_submit_documents(

return error_results

def index_clear_documents(self, pointer: str):
def index_clear_documents(self, pointer: str) -> None:
self._client.delete_by_query(
index=pointer, body={"query": {"match_all": {}}}, wait_for_completion=True
)

def refresh(self):
def refresh(self) -> None:
self.log.debug(f"waiting for indexes to become ready")
self._client.indices.refresh()

Expand All @@ -336,20 +340,22 @@ def write_pointer_set(self, revision: SearchSchemaRevision) -> None:

def read_pointer(self) -> str | None:
try:
result: dict = self._client.indices.get_alias(name=self.read_pointer_name())
result: dict[str, Any] = self._client.indices.get_alias(
name=self.read_pointer_name()
)
for name in result.keys():
if name.startswith(f"{self.base_revision_name}-"):
return name
return None
except NotFoundError:
return None

def search_client(self, write=False) -> Search:
def search_client(self, write: bool = False) -> Search:
return self._search.index(
self.read_pointer_name() if not write else self.write_pointer_name()
)

def search_multi_client(self, write=False) -> MultiSearch:
def search_multi_client(self, write: bool = False) -> MultiSearch:
return self._multi_search.index(
self.read_pointer_name() if not write else self.write_pointer_name()
)
Expand All @@ -361,10 +367,10 @@ def write_pointer_name(self) -> str:
return f"{self.base_revision_name}-search-write"

@staticmethod
def _empty(base_name):
def _empty(base_name: str) -> str:
return f"{base_name}-empty"

def index_remove_document(self, pointer: str, id: int):
def index_remove_document(self, pointer: str, id: int) -> None:
self._client.delete(index=pointer, id=id, doc_type="_doc")

def is_pointer_empty(self, pointer: str) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/palace/manager/search/v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class SearchV5(SearchSchemaRevision):
CHAR_FILTERS[name] = normalizer
AUTHOR_CHAR_FILTER_NAMES.append(name)

def __init__(self):
def __init__(self) -> None:
super().__init__()

self._normalizers = {}
Expand Down
Loading

0 comments on commit ffb683f

Please sign in to comment.