From fc8c0e00c8009da59b2747228692cda35880d139 Mon Sep 17 00:00:00 2001 From: dblock Date: Mon, 6 Nov 2023 16:20:50 -0500 Subject: [PATCH] Added missing types. Signed-off-by: dblock --- benchmarks/bench_async.py | 5 +- benchmarks/bench_info_sync.py | 7 +- benchmarks/bench_sync.py | 7 +- benchmarks/thread_with_return_value.py | 19 ++- docs/source/conf.py | 34 +++-- noxfile.py | 25 +-- opensearchpy/__init__.py | 1 + opensearchpy/_async/helpers/document.py | 14 +- opensearchpy/_async/helpers/index.py | 2 +- opensearchpy/_async/http_aiohttp.py | 2 +- opensearchpy/client/utils.py | 8 +- opensearchpy/connection/async_connections.py | 2 + opensearchpy/connection_pool.py | 6 +- opensearchpy/helpers/actions.py | 10 +- opensearchpy/helpers/asyncsigner.py | 16 +- opensearchpy/helpers/field.py | 4 +- opensearchpy/helpers/index.py | 2 +- opensearchpy/helpers/query.py | 8 +- opensearchpy/helpers/search.py | 3 + opensearchpy/helpers/utils.py | 4 +- opensearchpy/transport.py | 2 +- samples/bulk/bulk-array.py | 3 +- samples/hello/hello-async.py | 2 +- samples/json/json-hello-async.py | 2 +- samples/knn/knn-async-basics.py | 2 +- test_opensearchpy/TestHttpServer.py | 4 +- test_opensearchpy/run_tests.py | 5 +- .../test_async/test_connection.py | 42 ++--- .../test_async/test_helpers/conftest.py | 20 +-- .../test_async/test_helpers/test_document.py | 121 ++++++++------- .../test_helpers/test_faceted_search.py | 11 +- .../test_async/test_helpers/test_index.py | 11 +- .../test_async/test_helpers/test_mapping.py | 6 +- .../test_async/test_helpers/test_search.py | 25 +-- .../test_helpers/test_update_by_query.py | 8 +- .../test_async/test_http_connection.py | 16 +- .../test_async/test_plugins_client.py | 3 +- .../test_async/test_server/conftest.py | 9 +- .../test_async/test_server/test_clients.py | 10 +- .../test_server/test_helpers/conftest.py | 27 ++-- .../test_server/test_helpers/test_actions.py | 135 +++++++++-------- .../test_server/test_helpers/test_data.py | 8 +- .../test_server/test_helpers/test_document.py | 91 ++++++----- .../test_helpers/test_faceted_search.py | 33 ++-- .../test_server/test_helpers/test_index.py | 20 ++- .../test_server/test_helpers/test_mapping.py | 10 +- .../test_server/test_helpers/test_search.py | 22 +-- .../test_helpers/test_update_by_query.py | 14 +- .../test_server/test_plugins/test_alerting.py | 6 +- .../test_server/test_rest_api_spec.py | 35 ++--- .../test_security_plugin.py | 2 +- test_opensearchpy/test_async/test_signer.py | 12 +- .../test_async/test_transport.py | 130 ++++++++-------- test_opensearchpy/test_cases.py | 32 ++-- .../test_plugins/test_plugins_client.py | 3 +- test_opensearchpy/test_client/test_utils.py | 8 +- .../test_connection/test_base_connection.py | 4 +- .../test_requests_http_connection.py | 39 +++-- .../test_urllib3_http_connection.py | 20 ++- test_opensearchpy/test_connection_pool.py | 7 +- test_opensearchpy/test_helpers/conftest.py | 20 +-- .../test_helpers/test_actions.py | 25 +-- test_opensearchpy/test_helpers/test_aggs.py | 26 ++-- .../test_helpers/test_analysis.py | 12 +- .../test_helpers/test_document.py | 143 ++++++++++-------- .../test_helpers/test_faceted_search.py | 11 +- test_opensearchpy/test_helpers/test_field.py | 15 +- test_opensearchpy/test_helpers/test_index.py | 37 ++--- .../test_helpers/test_mapping.py | 6 +- test_opensearchpy/test_helpers/test_query.py | 14 +- test_opensearchpy/test_helpers/test_result.py | 41 ++--- test_opensearchpy/test_helpers/test_search.py | 77 +++++----- .../test_helpers/test_update_by_query.py | 11 +- test_opensearchpy/test_helpers/test_utils.py | 4 +- .../test_helpers/test_validation.py | 45 +++--- .../test_helpers/test_wrappers.py | 21 +-- test_opensearchpy/test_serializer.py | 3 +- test_opensearchpy/test_server/__init__.py | 9 +- test_opensearchpy/test_server/conftest.py | 15 +- .../test_server/test_helpers/conftest.py | 30 ++-- .../test_server/test_helpers/test_actions.py | 66 ++++---- .../test_server/test_helpers/test_analysis.py | 8 +- .../test_server/test_helpers/test_count.py | 8 +- .../test_server/test_helpers/test_data.py | 8 +- .../test_server/test_helpers/test_document.py | 81 +++++----- .../test_helpers/test_faceted_search.py | 33 ++-- .../test_server/test_helpers/test_index.py | 20 +-- .../test_server/test_helpers/test_mapping.py | 10 +- .../test_server/test_helpers/test_search.py | 22 +-- .../test_helpers/test_update_by_query.py | 8 +- .../test_server/test_plugins/test_alerting.py | 6 +- .../test_server/test_rest_api_spec.py | 84 +++++----- .../test_server_secured/test_clients.py | 2 +- .../test_security_plugin.py | 4 +- test_opensearchpy/test_transport.py | 83 +++++----- test_opensearchpy/utils.py | 29 ++-- utils/build-dists.py | 15 +- utils/generate-api.py | 63 ++++---- 98 files changed, 1235 insertions(+), 1019 deletions(-) diff --git a/benchmarks/bench_async.py b/benchmarks/bench_async.py index a27a126c..baeb7d80 100644 --- a/benchmarks/bench_async.py +++ b/benchmarks/bench_async.py @@ -12,6 +12,7 @@ import asyncio import uuid +from typing import Any from opensearchpy import AsyncHttpConnection, AsyncOpenSearch @@ -22,7 +23,7 @@ item_count = 100 -async def index_records(client, item_count) -> None: +async def index_records(client: Any, item_count: int) -> None: await asyncio.gather( *[ client.index( @@ -39,7 +40,7 @@ async def index_records(client, item_count) -> None: ) -async def test_async(client_count=1, item_count=1): +async def test_async(client_count: int = 1, item_count: int = 1) -> None: clients = [] for i in range(client_count): clients.append( diff --git a/benchmarks/bench_info_sync.py b/benchmarks/bench_info_sync.py index 29b289cd..0c69a102 100644 --- a/benchmarks/bench_info_sync.py +++ b/benchmarks/bench_info_sync.py @@ -14,6 +14,7 @@ import logging import sys import time +from typing import Any from thread_with_return_value import ThreadWithReturnValue @@ -36,8 +37,8 @@ root.addHandler(handler) -def get_info(client, request_count): - tt = 0 +def get_info(client: Any, request_count: int) -> float: + tt: float = 0 for n in range(request_count): start = time.time() * 1000 client.info() @@ -46,7 +47,7 @@ def get_info(client, request_count): return tt -def test(thread_count=1, request_count=1, client_count=1): +def test(thread_count: int = 1, request_count: int = 1, client_count: int = 1) -> None: clients = [] for i in range(client_count): clients.append( diff --git a/benchmarks/bench_sync.py b/benchmarks/bench_sync.py index 83225ef9..004fa2e4 100644 --- a/benchmarks/bench_sync.py +++ b/benchmarks/bench_sync.py @@ -15,6 +15,7 @@ import sys import time import uuid +from typing import Any from thread_with_return_value import ThreadWithReturnValue @@ -37,10 +38,10 @@ root.addHandler(handler) -def index_records(client, item_count): +def index_records(client: Any, item_count: int) -> Any: tt = 0 for n in range(10): - data = [] + data: Any = [] for i in range(item_count): data.append( json.dumps({"index": {"_index": index_name, "_id": str(uuid.uuid4())}}) @@ -63,7 +64,7 @@ def index_records(client, item_count): return tt -def test(thread_count=1, item_count=1, client_count=1): +def test(thread_count: int = 1, item_count: int = 1, client_count: int = 1) -> None: clients = [] for i in range(client_count): clients.append( diff --git a/benchmarks/thread_with_return_value.py b/benchmarks/thread_with_return_value.py index b6bc9c09..089c6fde 100644 --- a/benchmarks/thread_with_return_value.py +++ b/benchmarks/thread_with_return_value.py @@ -10,19 +10,30 @@ from threading import Thread +from typing import Any, Optional class ThreadWithReturnValue(Thread): + _target: Any + _args: Any + _kwargs: Any + def __init__( - self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None - ): + self, + group: Any = None, + target: Any = None, + name: Optional[str] = None, + args: Any = (), + kwargs: Any = {}, + Verbose: Optional[bool] = None, + ) -> None: Thread.__init__(self, group, target, name, args, kwargs) self._return = None - def run(self): + def run(self) -> None: if self._target is not None: self._return = self._target(*self._args, **self._kwargs) - def join(self, *args): + def join(self, *args: Any) -> Any: Thread.join(self, *args) return self._return diff --git a/docs/source/conf.py b/docs/source/conf.py index 133a2564..64ff3c52 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,9 +26,11 @@ # -- Project information ----------------------------------------------------- -project = "OpenSearch Python Client" -copyright = "OpenSearch Project Contributors" -author = "OpenSearch Project Contributors" +from typing import Any + +project: str = "OpenSearch Python Client" +copyright: str = "OpenSearch Project Contributors" +author: str = "OpenSearch Project Contributors" # -- General configuration --------------------------------------------------- @@ -36,7 +38,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ +extensions: Any = [ "sphinx.ext.autodoc", "sphinx_rtd_theme", "sphinx.ext.viewcode", @@ -47,12 +49,12 @@ ] # Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] +templates_path: Any = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns: Any = [] # -- Options for HTML output ------------------------------------------------- @@ -60,31 +62,31 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "sphinx_rtd_theme" +html_theme: str = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] +html_static_path: Any = ["_static"] # -- additional settings ------------------------------------------------- -intersphinx_mapping = { +intersphinx_mapping: Any = { "python": ("https://docs.python.org/3", None), } -html_logo = "imgs/OpenSearch.svg" +html_logo: str = "imgs/OpenSearch.svg" # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) -html_css_files = [ +html_css_files: Any = [ "css/custom.css", ] # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -html_show_sphinx = False +html_show_sphinx: bool = False # add github link -html_context = { +html_context: Any = { "display_github": True, "github_user": "opensearch-project", "github_repo": "opensearch-py", @@ -94,18 +96,18 @@ # -- autodoc config ------------------------------------------------- # This value controls how to represent typehints. # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_typehints -autodoc_typehints = "description" +autodoc_typehints: str = "description" # This value selects what content will be inserted into the main body of an autoclass directive. # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autoclass_content -autoclass_content = "both" +autoclass_content: str = "both" # https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-add_module_names # add_module_names = False # The default options for autodoc directives. # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_default_options -autodoc_default_options = { +autodoc_default_options: Any = { # If set, autodoc will generate document for the members of the target module, class or exception. # noqa: E501 # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#directive-option-automodule-members "members": True, diff --git a/noxfile.py b/noxfile.py index 678d6af8..e9189cc9 100644 --- a/noxfile.py +++ b/noxfile.py @@ -26,6 +26,8 @@ # under the License. +from typing import Any + import nox SOURCE_FILES = ( @@ -40,16 +42,16 @@ ) -@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]) -def test(session) -> None: +@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]) # type: ignore +def test(session: Any) -> None: session.install(".") session.install("-r", "dev-requirements.txt") session.run("python", "setup.py", "test") -@nox.session() -def format(session) -> None: +@nox.session() # type: ignore +def format(session: Any) -> None: session.install("black", "isort") session.run("isort", "--profile=black", *SOURCE_FILES) @@ -59,8 +61,8 @@ def format(session) -> None: lint(session) -@nox.session(python=["3.7"]) -def lint(session) -> None: +@nox.session(python=["3.7"]) # type: ignore +def lint(session: Any) -> None: session.install( "flake8", "black", @@ -70,6 +72,9 @@ def lint(session) -> None: "types-six", "types-simplejson", "types-python-dateutil", + "types-PyYAML", + "types-mock", + "types-pytz", ) session.run("isort", "--check", "--profile=black", *SOURCE_FILES) @@ -93,8 +98,8 @@ def lint(session) -> None: session.run("mypy", "--strict", "test_opensearchpy/test_types/sync_types.py") -@nox.session() -def docs(session) -> None: +@nox.session() # type: ignore +def docs(session: Any) -> None: session.install(".") session.install( "-rdev-requirements.txt", "sphinx-rtd-theme", "sphinx-autodoc-typehints" @@ -102,8 +107,8 @@ def docs(session) -> None: session.run("python", "-m", "pip", "install", "sphinx-autodoc-typehints") -@nox.session() -def generate(session) -> None: +@nox.session() # type: ignore +def generate(session: Any) -> None: session.install("-rdev-requirements.txt") session.run("python", "utils/generate-api.py") format(session) diff --git a/opensearchpy/__init__.py b/opensearchpy/__init__.py index 3dcd7389..e9ef6485 100644 --- a/opensearchpy/__init__.py +++ b/opensearchpy/__init__.py @@ -256,4 +256,5 @@ "AsyncTransport", "AsyncOpenSearch", "AsyncHttpConnection", + "__versionstr__", ] diff --git a/opensearchpy/_async/helpers/document.py b/opensearchpy/_async/helpers/document.py index 25196e01..83349f7e 100644 --- a/opensearchpy/_async/helpers/document.py +++ b/opensearchpy/_async/helpers/document.py @@ -10,7 +10,7 @@ import collections.abc as collections_abc from fnmatch import fnmatch -from typing import Any, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Tuple, Type from six import add_metaclass @@ -128,9 +128,7 @@ def __repr__(self) -> str: ) @classmethod - def search( - cls, using: Optional[AsyncOpenSearch] = None, index: Optional[str] = None - ) -> AsyncSearch: + def search(cls, using: Any = None, index: Any = None) -> AsyncSearch: """ Create an :class:`~opensearchpy.AsyncSearch` instance that will search over this ``Document``. @@ -142,9 +140,9 @@ def search( @classmethod async def get( # type: ignore cls, - id: str, - using: Optional[AsyncOpenSearch] = None, - index: Optional[str] = None, + id: Any, + using: Any = None, + index: Any = None, **kwargs: Any, ) -> Any: """ @@ -189,7 +187,7 @@ async def exists( @classmethod async def mget( cls, - docs: Sequence[str], + docs: Any, using: Optional[AsyncOpenSearch] = None, index: Optional[str] = None, raise_on_error: Optional[bool] = True, diff --git a/opensearchpy/_async/helpers/index.py b/opensearchpy/_async/helpers/index.py index ea06f316..4f2a9918 100644 --- a/opensearchpy/_async/helpers/index.py +++ b/opensearchpy/_async/helpers/index.py @@ -59,7 +59,7 @@ async def save(self, using: Any = None) -> Any: class AsyncIndex(object): - def __init__(self, name: Any, using: str = "default") -> None: + def __init__(self, name: Any, using: Any = "default") -> None: """ :arg name: name of the index :arg using: connection alias to use, defaults to ``'default'`` diff --git a/opensearchpy/_async/http_aiohttp.py b/opensearchpy/_async/http_aiohttp.py index 34819970..3c7010ed 100644 --- a/opensearchpy/_async/http_aiohttp.py +++ b/opensearchpy/_async/http_aiohttp.py @@ -69,7 +69,7 @@ async def close(self) -> None: class AIOHttpConnection(AsyncConnection): - session: Optional[aiohttp.ClientSession] + session: aiohttp.ClientSession ssl_assert_fingerprint: Optional[str] def __init__( diff --git a/opensearchpy/client/utils.py b/opensearchpy/client/utils.py index 3ae204e6..d2196d11 100644 --- a/opensearchpy/client/utils.py +++ b/opensearchpy/client/utils.py @@ -32,7 +32,7 @@ import weakref from datetime import date, datetime from functools import wraps -from typing import Any, Callable +from typing import Any, Callable, Optional from opensearchpy.serializer import Serializer @@ -185,16 +185,16 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any: return _wrapper -def _bulk_body(serializer: Serializer, body: str) -> str: +def _bulk_body(serializer: Optional[Serializer], body: Any) -> Any: # if not passed in a string, serialize items and join by newline - if not isinstance(body, string_types): + if serializer and not isinstance(body, str): body = "\n".join(map(serializer.dumps, body)) # bulk body must end with a newline if isinstance(body, bytes): if not body.endswith(b"\n"): body += b"\n" - elif isinstance(body, string_types) and not body.endswith("\n"): + elif isinstance(body, str) and not body.endswith("\n"): body += "\n" return body diff --git a/opensearchpy/connection/async_connections.py b/opensearchpy/connection/async_connections.py index 87467ae0..670bbaeb 100644 --- a/opensearchpy/connection/async_connections.py +++ b/opensearchpy/connection/async_connections.py @@ -18,6 +18,8 @@ class AsyncConnections(object): + _conns: Any + """ Class responsible for holding connections to different clusters. Used as a singleton in this module. diff --git a/opensearchpy/connection_pool.py b/opensearchpy/connection_pool.py index defef6f5..378b91b3 100644 --- a/opensearchpy/connection_pool.py +++ b/opensearchpy/connection_pool.py @@ -124,7 +124,7 @@ class ConnectionPool(object): connections: Any orig_connections: Tuple[Connection, ...] dead: Any - dead_count: Dict[Connection, int] + dead_count: Dict[Any, int] dead_timeout: float timeout_cutoff: int selector: Any @@ -173,7 +173,7 @@ def __init__( self.selector = selector_class(dict(connections)) # type: ignore - def mark_dead(self, connection: Connection, now: Optional[float] = None) -> None: + def mark_dead(self, connection: Any, now: Optional[float] = None) -> None: """ Mark the connection as dead (failed). Remove it from the live pool and put it on a timeout. @@ -203,7 +203,7 @@ def mark_dead(self, connection: Connection, now: Optional[float] = None) -> None timeout, ) - def mark_live(self, connection: Connection) -> None: + def mark_live(self, connection: Any) -> None: """ Mark connection as healthy after a resurrection. Resets the fail counter for the connection. diff --git a/opensearchpy/helpers/actions.py b/opensearchpy/helpers/actions.py index 39e3cdaf..7f8ced35 100644 --- a/opensearchpy/helpers/actions.py +++ b/opensearchpy/helpers/actions.py @@ -503,12 +503,12 @@ def _setup_queues(self) -> None: def scan( client: Any, query: Any = None, - scroll: str = "5m", - raise_on_error: bool = True, - preserve_order: bool = False, - size: int = 1000, + scroll: Optional[str] = "5m", + raise_on_error: Optional[bool] = True, + preserve_order: Optional[bool] = False, + size: Optional[int] = 1000, request_timeout: Optional[float] = None, - clear_scroll: bool = True, + clear_scroll: Optional[bool] = True, scroll_kwargs: Any = None, **kwargs: Any ) -> Any: diff --git a/opensearchpy/helpers/asyncsigner.py b/opensearchpy/helpers/asyncsigner.py index bd84e09e..8dee4fee 100644 --- a/opensearchpy/helpers/asyncsigner.py +++ b/opensearchpy/helpers/asyncsigner.py @@ -8,7 +8,7 @@ # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. -from typing import Dict, Union +from typing import Any, Dict, Optional, Union class AWSV4SignerAsyncAuth: @@ -16,7 +16,7 @@ class AWSV4SignerAsyncAuth: AWS V4 Request Signer for Async Requests. """ - def __init__(self, credentials, region: str, service: str = "es") -> None: # type: ignore + def __init__(self, credentials: Any, region: str, service: str = "es") -> None: if not credentials: raise ValueError("Credentials cannot be empty") self.credentials = credentials @@ -30,12 +30,20 @@ def __init__(self, credentials, region: str, service: str = "es") -> None: # ty self.service = service def __call__( - self, method: str, url: str, query_string: str, body: Union[str, bytes] + self, + method: str, + url: str, + query_string: Optional[str] = None, + body: Optional[Union[str, bytes]] = None, ) -> Dict[str, str]: return self._sign_request(method, url, query_string, body) def _sign_request( - self, method: str, url: str, query_string: str, body: Union[str, bytes] + self, + method: str, + url: str, + query_string: Optional[str], + body: Optional[Union[str, bytes]], ) -> Dict[str, str]: """ This method helps in signing the request by injecting the required headers. diff --git a/opensearchpy/helpers/field.py b/opensearchpy/helpers/field.py index 4881e819..4ffd21d8 100644 --- a/opensearchpy/helpers/field.py +++ b/opensearchpy/helpers/field.py @@ -268,9 +268,7 @@ class Date(Field): name: Optional[str] = "date" _coerce: bool = True - def __init__( - self, default_timezone: None = None, *args: Any, **kwargs: Any - ) -> None: + def __init__(self, default_timezone: Any = None, *args: Any, **kwargs: Any) -> None: """ :arg default_timezone: timezone that will be automatically used for tz-naive values May be instance of `datetime.tzinfo` or string containing TZ offset diff --git a/opensearchpy/helpers/index.py b/opensearchpy/helpers/index.py index e96136b2..3fbb475a 100644 --- a/opensearchpy/helpers/index.py +++ b/opensearchpy/helpers/index.py @@ -78,7 +78,7 @@ def save(self, using: Any = None) -> Any: class Index(object): - def __init__(self, name: Any, using: str = "default") -> None: + def __init__(self, name: Any, using: Any = "default") -> None: """ :arg name: name of the index :arg using: connection alias to use, defaults to ``'default'`` diff --git a/opensearchpy/helpers/query.py b/opensearchpy/helpers/query.py index dc2db8a7..e299f94a 100644 --- a/opensearchpy/helpers/query.py +++ b/opensearchpy/helpers/query.py @@ -31,12 +31,11 @@ # 'SF' looks unused but the test suite assumes it's available # from this module so others are liable to do so as well. -from ..helpers.function import SF # noqa: F401 -from ..helpers.function import ScoreFunction +from ..helpers.function import SF, ScoreFunction from .utils import DslBase -def Q(name_or_query: str = "match_all", **params: Any) -> Any: +def Q(name_or_query: Any = "match_all", **params: Any) -> Any: # {"match": {"title": "python"}} if isinstance(name_or_query, collections_abc.Mapping): if params: @@ -521,3 +520,6 @@ class ParentId(Query): class Wrapper(Query): name = "wrapper" + + +__all__ = ["SF"] diff --git a/opensearchpy/helpers/search.py b/opensearchpy/helpers/search.py index 46ba9da9..069f4c89 100644 --- a/opensearchpy/helpers/search.py +++ b/opensearchpy/helpers/search.py @@ -864,3 +864,6 @@ def execute(self, ignore_cache: Any = False, raise_on_error: Any = True) -> Any: self._response = out return self._response + + +__all__ = ["Q"] diff --git a/opensearchpy/helpers/utils.py b/opensearchpy/helpers/utils.py index e17b89a6..2a9f19da 100644 --- a/opensearchpy/helpers/utils.py +++ b/opensearchpy/helpers/utils.py @@ -284,7 +284,7 @@ def get_dsl_class(cls: Any, name: Any, default: Optional[bool] = None) -> Any: "DSL class `{}` does not exist in {}.".format(name, cls._type_name) ) - def __init__(self, _expand__to_dot: bool = EXPAND__TO_DOT, **params: Any) -> None: + def __init__(self, _expand__to_dot: Any = EXPAND__TO_DOT, **params: Any) -> None: self._params = {} for pname, pvalue in iteritems(params): if "__" in pname and _expand__to_dot: @@ -438,6 +438,8 @@ def __init__( class ObjectBase(AttrDict): + _doc_type: Any + def __init__(self, meta: Any = None, **kwargs: Any) -> None: meta = meta or {} for k in list(kwargs): diff --git a/opensearchpy/transport.py b/opensearchpy/transport.py index 583d9ba7..44962542 100644 --- a/opensearchpy/transport.py +++ b/opensearchpy/transport.py @@ -373,7 +373,7 @@ def perform_request( method: str, url: str, params: Optional[Mapping[str, Any]] = None, - body: Optional[bytes] = None, + body: Any = None, timeout: Optional[Union[int, float]] = None, ignore: Collection[int] = (), headers: Optional[Mapping[str, str]] = None, diff --git a/samples/bulk/bulk-array.py b/samples/bulk/bulk-array.py index 1859d541..5191a291 100755 --- a/samples/bulk/bulk-array.py +++ b/samples/bulk/bulk-array.py @@ -12,6 +12,7 @@ import os +from typing import Any from opensearchpy import OpenSearch @@ -45,7 +46,7 @@ ) # index data -data = [] +data: Any = [] for i in range(100): data.append({"index": {"_index": index_name, "_id": i}}) data.append({"value": i}) diff --git a/samples/hello/hello-async.py b/samples/hello/hello-async.py index 9975f575..8606a17d 100755 --- a/samples/hello/hello-async.py +++ b/samples/hello/hello-async.py @@ -16,7 +16,7 @@ from opensearchpy import AsyncOpenSearch -async def main(): +async def main() -> None: # connect to OpenSearch host = "localhost" port = 9200 diff --git a/samples/json/json-hello-async.py b/samples/json/json-hello-async.py index b9105d35..fbadece6 100755 --- a/samples/json/json-hello-async.py +++ b/samples/json/json-hello-async.py @@ -16,7 +16,7 @@ from opensearchpy import AsyncOpenSearch -async def main(): +async def main() -> None: # connect to OpenSearch host = "localhost" port = 9200 diff --git a/samples/knn/knn-async-basics.py b/samples/knn/knn-async-basics.py index a7bb9d2f..aa0acf6e 100755 --- a/samples/knn/knn-async-basics.py +++ b/samples/knn/knn-async-basics.py @@ -18,7 +18,7 @@ from opensearchpy import AsyncHttpConnection, AsyncOpenSearch, helpers -async def main(): +async def main() -> None: # connect to an instance of OpenSearch host = os.getenv("HOST", default="localhost") port = int(os.getenv("PORT", 9200)) diff --git a/test_opensearchpy/TestHttpServer.py b/test_opensearchpy/TestHttpServer.py index ba83e041..3d8b31fb 100644 --- a/test_opensearchpy/TestHttpServer.py +++ b/test_opensearchpy/TestHttpServer.py @@ -11,10 +11,11 @@ import json import threading from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any class TestHTTPRequestHandler(BaseHTTPRequestHandler): - def do_GET(self): + def do_GET(self) -> None: headers = self.headers if self.path == "/redirect": @@ -40,6 +41,7 @@ def do_GET(self): class TestHTTPServer(HTTPServer): __test__ = False + _server_thread: Any def __init__(self, host: str = "localhost", port: int = 8080) -> None: super().__init__((host, port), TestHTTPRequestHandler) diff --git a/test_opensearchpy/run_tests.py b/test_opensearchpy/run_tests.py index de93adc7..b37fd598 100755 --- a/test_opensearchpy/run_tests.py +++ b/test_opensearchpy/run_tests.py @@ -37,6 +37,7 @@ import sys from os import environ from os.path import abspath, dirname, exists, join, pardir +from typing import Any def fetch_opensearch_repo() -> None: @@ -88,8 +89,8 @@ def fetch_opensearch_repo() -> None: subprocess.check_call("cd %s && git fetch origin %s" % (repo_path, sha), shell=True) -def run_all(argv: None = None) -> None: - sys.exitfunc = lambda: sys.stderr.write("Shutting down....\n") +def run_all(argv: Any = None) -> None: + sys.exitfunc = lambda: sys.stderr.write("Shutting down....\n") # type: ignore # fetch yaml tests anywhere that's not GitHub Actions if "GITHUB_ACTION" not in environ: fetch_opensearch_repo() diff --git a/test_opensearchpy/test_async/test_connection.py b/test_opensearchpy/test_async/test_connection.py index e72a2358..9413d0e8 100644 --- a/test_opensearchpy/test_async/test_connection.py +++ b/test_opensearchpy/test_async/test_connection.py @@ -32,6 +32,7 @@ import ssl import warnings from platform import python_version +from typing import Any import aiohttp import pytest @@ -52,29 +53,29 @@ class TestAIOHttpConnection: async def _get_mock_connection( self, - connection_params={}, + connection_params: Any = {}, response_code: int = 200, response_body: bytes = b"{}", - response_headers={}, - ): + response_headers: Any = {}, + ) -> Any: con = AIOHttpConnection(**connection_params) await con._create_aiohttp_session() - def _dummy_request(*args, **kwargs): + def _dummy_request(*args: Any, **kwargs: Any) -> Any: class DummyResponse: - async def __aenter__(self, *_, **__): + async def __aenter__(self, *_: Any, **__: Any) -> Any: return self - async def __aexit__(self, *_, **__): + async def __aexit__(self, *_: Any, **__: Any) -> None: pass - async def text(self): + async def text(self) -> Any: return response_body.decode("utf-8", "surrogatepass") - dummy_response = DummyResponse() + dummy_response: Any = DummyResponse() dummy_response.headers = CIMultiDict(**response_headers) dummy_response.status = response_code - _dummy_request.call_args = (args, kwargs) + _dummy_request.call_args = (args, kwargs) # type: ignore return dummy_response con.session.request = _dummy_request @@ -231,6 +232,7 @@ async def test_no_warning_when_using_ssl_context(self) -> None: assert w == [], str([x.message for x in w]) async def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self) -> None: + kwargs: Any for kwargs in ( {"ssl_show_warn": False}, {"ssl_show_warn": True}, @@ -253,26 +255,28 @@ async def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self) -> N ) @patch("ssl.SSLContext.load_verify_locations") - async def test_uses_given_ca_certs(self, load_verify_locations, tmp_path) -> None: + async def test_uses_given_ca_certs( + self, load_verify_locations: Any, tmp_path: Any + ) -> None: path = tmp_path / "ca_certs.pem" path.touch() AIOHttpConnection(use_ssl=True, ca_certs=str(path)) load_verify_locations.assert_called_once_with(cafile=str(path)) @patch("ssl.SSLContext.load_verify_locations") - async def test_uses_default_ca_certs(self, load_verify_locations) -> None: + async def test_uses_default_ca_certs(self, load_verify_locations: Any) -> None: AIOHttpConnection(use_ssl=True) load_verify_locations.assert_called_once_with( cafile=Connection.default_ca_certs() ) @patch("ssl.SSLContext.load_verify_locations") - async def test_uses_no_ca_certs(self, load_verify_locations) -> None: + async def test_uses_no_ca_certs(self, load_verify_locations: Any) -> None: AIOHttpConnection(use_ssl=True, verify_certs=False) load_verify_locations.assert_not_called() async def test_trust_env(self) -> None: - con = AIOHttpConnection(trust_env=True) + con: Any = AIOHttpConnection(trust_env=True) await con._create_aiohttp_session() assert con._trust_env is True @@ -286,7 +290,7 @@ async def test_trust_env_default_value_is_false(self) -> None: assert con.session.trust_env is False @patch("opensearchpy.connection.base.logger") - async def test_uncompressed_body_logged(self, logger) -> None: + async def test_uncompressed_body_logged(self, logger: Any) -> None: con = await self._get_mock_connection(connection_params={"http_compress": True}) await con.perform_request("GET", "/", body=b'{"example": "body"}') @@ -302,11 +306,11 @@ async def test_surrogatepass_into_bytes(self) -> None: status, headers, data = await con.perform_request("GET", "/") assert u"你好\uda6a" == data # fmt: skip - @pytest.mark.parametrize("exception_cls", reraise_exceptions) - async def test_recursion_error_reraised(self, exception_cls) -> None: + @pytest.mark.parametrize("exception_cls", reraise_exceptions) # type: ignore + async def test_recursion_error_reraised(self, exception_cls: Any) -> None: conn = AIOHttpConnection() - def request_raise(*_, **__): + def request_raise(*_: Any, **__: Any) -> Any: raise exception_cls("Wasn't modified!") await conn._create_aiohttp_session() @@ -334,6 +338,8 @@ async def test_json_errors_are_parsed(self) -> None: class TestConnectionHttpServer: """Tests the HTTP connection implementations against a live server E2E""" + server: Any + @classmethod def setup_class(cls) -> None: # Start server @@ -345,7 +351,7 @@ def teardown_class(cls) -> None: # Stop server cls.server.stop() - async def httpserver(self, conn, **kwargs): + async def httpserver(self, conn: Any, **kwargs: Any) -> Any: status, headers, data = await conn.perform_request("GET", "/", **kwargs) data = json.loads(data) return (status, data) diff --git a/test_opensearchpy/test_async/test_helpers/conftest.py b/test_opensearchpy/test_async/test_helpers/conftest.py index f24b8a48..bd1776ab 100644 --- a/test_opensearchpy/test_async/test_helpers/conftest.py +++ b/test_opensearchpy/test_async/test_helpers/conftest.py @@ -9,6 +9,8 @@ # GitHub history for details. +from typing import Any + import pytest from _pytest.mark.structures import MarkDecorator from mock import Mock @@ -19,18 +21,18 @@ pytestmark: MarkDecorator = pytest.mark.asyncio -@fixture -async def mock_client(dummy_response): +@fixture # type: ignore +async def mock_client(dummy_response: Any) -> Any: client = Mock() client.search.return_value = dummy_response await add_connection("mock", client) yield client - async_connections._conn = {} + async_connections._conns = {} async_connections._kwargs = {} -@fixture -def dummy_response(): +@fixture # type: ignore +def dummy_response() -> Any: return { "_shards": {"failed": 0, "successful": 10, "total": 10}, "hits": { @@ -78,8 +80,8 @@ def dummy_response(): } -@fixture -def aggs_search(): +@fixture # type: ignore +def aggs_search() -> Any: from opensearchpy._async.helpers.search import AsyncSearch s = AsyncSearch(index="flat-git") @@ -93,8 +95,8 @@ def aggs_search(): return s -@fixture -def aggs_data(): +@fixture # type: ignore +def aggs_data() -> Any: return { "took": 4, "timed_out": False, diff --git a/test_opensearchpy/test_async/test_helpers/test_document.py b/test_opensearchpy/test_async/test_helpers/test_document.py index d13c7272..d6ef0128 100644 --- a/test_opensearchpy/test_async/test_helpers/test_document.py +++ b/test_opensearchpy/test_async/test_helpers/test_document.py @@ -15,6 +15,7 @@ import pickle from datetime import datetime from hashlib import sha256 +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator @@ -31,25 +32,25 @@ class MyInner(InnerDoc): - old_field = field.Text() + old_field: Any = field.Text() class MyDoc(document.AsyncDocument): - title = field.Keyword() - name = field.Text() - created_at = field.Date() - inner = field.Object(MyInner) + title: Any = field.Keyword() + name: Any = field.Text() + created_at: Any = field.Date() + inner: Any = field.Object(MyInner) class MySubDoc(MyDoc): - name = field.Keyword() + name: Any = field.Keyword() class Index: name = "default-index" class MyDoc2(document.AsyncDocument): - extra = field.Long() + extra: Any = field.Long() class MyMultiSubDoc(MyDoc2, MySubDoc): @@ -57,19 +58,19 @@ class MyMultiSubDoc(MyDoc2, MySubDoc): class Comment(InnerDoc): - title = field.Text() - tags = field.Keyword(multi=True) + title: Any = field.Text() + tags: Any = field.Keyword(multi=True) class DocWithNested(document.AsyncDocument): - comments = field.Nested(Comment) + comments: Any = field.Nested(Comment) class Index: name = "test-doc-with-nested" class SimpleCommit(document.AsyncDocument): - files = field.Text(multi=True) + files: Any = field.Text(multi=True) class Index: name = "test-git" @@ -80,48 +81,54 @@ class Secret(str): class SecretField(field.CustomField): - builtin_type = "text" + builtin_type: Any = "text" - def _serialize(self, data): + def _serialize(self, data: Any) -> Any: return codecs.encode(data, "rot_13") - def _deserialize(self, data): + def _deserialize(self, data: Any) -> Any: if isinstance(data, Secret): return data return Secret(codecs.decode(data, "rot_13")) class SecretDoc(document.AsyncDocument): - title = SecretField(index="no") + title: Any = SecretField(index="no") class Index: name = "test-secret-doc" class NestedSecret(document.AsyncDocument): - secrets = field.Nested(SecretDoc) + secrets: Any = field.Nested(SecretDoc) class Index: name = "test-nested-secret" + _index: Any + class OptionalObjectWithRequiredField(document.AsyncDocument): - comments = field.Nested(properties={"title": field.Keyword(required=True)}) + comments: Any = field.Nested(properties={"title": field.Keyword(required=True)}) class Index: name = "test-required" + _index: Any + class Host(document.AsyncDocument): - ip = field.Ip() + ip: Any = field.Ip() class Index: name = "test-host" + _index: Any + async def test_range_serializes_properly() -> None: class D(document.AsyncDocument): - lr = field.LongRange() + lr: Any = field.LongRange() d = D(lr=Range(lt=42)) assert 40 in d.lr @@ -200,7 +207,7 @@ async def test_assigning_attrlist_to_field() -> None: async def test_optional_inner_objects_are_not_validated_if_missing() -> None: - d = OptionalObjectWithRequiredField() + d: Any = OptionalObjectWithRequiredField() assert d.full_clean() is None @@ -253,13 +260,15 @@ async def test_null_value_for_object() -> None: assert d.inner is None -async def test_inherited_doc_types_can_override_index(): +async def test_inherited_doc_types_can_override_index() -> None: class MyDocDifferentIndex(MySubDoc): + _index: Any + class Index: - name = "not-default-index" - settings = {"number_of_replicas": 0} - aliases = {"a": {}} - analyzers = [analyzer("my_analizer", tokenizer="keyword")] + name: Any = "not-default-index" + settings: Any = {"number_of_replicas": 0} + aliases: Any = {"a": {}} + analyzers: Any = [analyzer("my_analizer", tokenizer="keyword")] assert MyDocDifferentIndex._index._name == "not-default-index" assert MyDocDifferentIndex()._get_index() == "not-default-index" @@ -285,7 +294,7 @@ class Index: } -async def test_to_dict_with_meta(): +async def test_to_dict_with_meta() -> None: d = MySubDoc(title="hello") d.meta.routing = "some-parent" @@ -296,7 +305,7 @@ async def test_to_dict_with_meta(): } == d.to_dict(True) -async def test_to_dict_with_meta_includes_custom_index(): +async def test_to_dict_with_meta_includes_custom_index() -> None: d = MySubDoc(title="hello") d.meta.index = "other-index" @@ -340,7 +349,7 @@ async def test_meta_is_accessible_even_on_empty_doc() -> None: d.meta -async def test_meta_field_mapping(): +async def test_meta_field_mapping() -> None: class User(document.AsyncDocument): username = field.Text() @@ -372,17 +381,17 @@ class Blog(document.AsyncDocument): async def test_docs_with_properties() -> None: class User(document.AsyncDocument): - pwd_hash = field.Text() + pwd_hash: Any = field.Text() - def check_password(self, pwd): + def check_password(self, pwd: Any) -> Any: return sha256(pwd).hexdigest() == self.pwd_hash @property - def password(self): + def password(self) -> Any: raise AttributeError("readonly") @password.setter - def password(self, pwd): + def password(self, pwd: Any) -> None: self.pwd_hash = sha256(pwd).hexdigest() u = User(pwd_hash=sha256(b"secret").hexdigest()) @@ -424,8 +433,8 @@ async def test_nested_defaults_to_list_and_can_be_updated() -> None: assert {"comments": [{"title": "hello World!"}]} == md.to_dict() -async def test_to_dict_is_recursive_and_can_cope_with_multi_values(): - md = MyDoc(name=["a", "b", "c"]) +async def test_to_dict_is_recursive_and_can_cope_with_multi_values() -> None: + md: Any = MyDoc(name=["a", "b", "c"]) md.inner = [MyInner(old_field="of1"), MyInner(old_field="of2")] assert isinstance(md.inner[0], MyInner) @@ -437,12 +446,12 @@ async def test_to_dict_is_recursive_and_can_cope_with_multi_values(): async def test_to_dict_ignores_empty_collections() -> None: - md = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) + md: Any = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) assert {"name": "", "count": 0, "valid": False} == md.to_dict() -async def test_declarative_mapping_definition(): +async def test_declarative_mapping_definition() -> None: assert issubclass(MyDoc, document.AsyncDocument) assert hasattr(MyDoc, "_doc_type") assert { @@ -455,7 +464,7 @@ async def test_declarative_mapping_definition(): } == MyDoc._doc_type.mapping.to_dict() -async def test_you_can_supply_own_mapping_instance(): +async def test_you_can_supply_own_mapping_instance() -> None: class MyD(document.AsyncDocument): title = field.Text() @@ -469,9 +478,9 @@ class Meta: } == MyD._doc_type.mapping.to_dict() -async def test_document_can_be_created_dynamically(): +async def test_document_can_be_created_dynamically() -> None: n = datetime.now() - md = MyDoc(title="hello") + md: Any = MyDoc(title="hello") md.name = "My Fancy Document!" md.created_at = n @@ -491,13 +500,13 @@ async def test_document_can_be_created_dynamically(): async def test_invalid_date_will_raise_exception() -> None: - md = MyDoc() + md: Any = MyDoc() md.created_at = "not-a-date" with raises(ValidationException): md.full_clean() -async def test_document_inheritance(): +async def test_document_inheritance() -> None: assert issubclass(MySubDoc, MyDoc) assert issubclass(MySubDoc, document.AsyncDocument) assert hasattr(MySubDoc, "_doc_type") @@ -511,7 +520,7 @@ async def test_document_inheritance(): } == MySubDoc._doc_type.mapping.to_dict() -async def test_child_class_can_override_parent(): +async def test_child_class_can_override_parent() -> None: class A(document.AsyncDocument): o = field.Object(dynamic=False, properties={"a": field.Text()}) @@ -530,7 +539,7 @@ class B(A): async def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict() -> None: - md = MySubDoc(meta={"id": 42}, name="My First doc!") + md: Any = MySubDoc(meta={"id": 42}, name="My First doc!") md.meta.index = "my-index" assert md.meta.index == "my-index" @@ -539,7 +548,7 @@ async def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict() -> None: assert {"id": 42, "index": "my-index"} == md.meta.to_dict() -async def test_index_inheritance(): +async def test_index_inheritance() -> None: assert issubclass(MyMultiSubDoc, MySubDoc) assert issubclass(MyMultiSubDoc, MyDoc2) assert issubclass(MyMultiSubDoc, document.AsyncDocument) @@ -558,31 +567,31 @@ async def test_index_inheritance(): async def test_meta_fields_can_be_set_directly_in_init() -> None: p = object() - md = MyDoc(_id=p, title="Hello World!") + md: Any = MyDoc(_id=p, title="Hello World!") assert md.meta.id is p -async def test_save_no_index(mock_client) -> None: - md = MyDoc() +async def test_save_no_index(mock_client: Any) -> None: + md: Any = MyDoc() with raises(ValidationException): await md.save(using="mock") -async def test_delete_no_index(mock_client) -> None: - md = MyDoc() +async def test_delete_no_index(mock_client: Any) -> None: + md: Any = MyDoc() with raises(ValidationException): await md.delete(using="mock") async def test_update_no_fields() -> None: - md = MyDoc() + md: Any = MyDoc() with raises(IllegalOperation): await md.update() -async def test_search_with_custom_alias_and_index(mock_client) -> None: - search_object = MyDoc.search( +async def test_search_with_custom_alias_and_index(mock_client: Any) -> None: + search_object: Any = MyDoc.search( using="staging", index=["custom_index1", "custom_index2"] ) @@ -590,8 +599,8 @@ async def test_search_with_custom_alias_and_index(mock_client) -> None: assert search_object._index == ["custom_index1", "custom_index2"] -async def test_from_opensearch_respects_underscored_non_meta_fields(): - doc = { +async def test_from_opensearch_respects_underscored_non_meta_fields() -> None: + doc: Any = { "_index": "test-index", "_id": "opensearch", "_score": 12.0, @@ -614,11 +623,11 @@ class Index: assert c._tagline == "You know, for search" -async def test_nested_and_object_inner_doc(): +async def test_nested_and_object_inner_doc() -> None: class MySubDocWithNested(MyDoc): nested_inner = field.Nested(MyInner) - props = MySubDocWithNested._doc_type.mapping.to_dict()["properties"] + props: Any = MySubDocWithNested._doc_type.mapping.to_dict()["properties"] assert props == { "created_at": {"type": "date"}, "inner": {"properties": {"old_field": {"type": "text"}}, "type": "object"}, diff --git a/test_opensearchpy/test_async/test_helpers/test_faceted_search.py b/test_opensearchpy/test_async/test_helpers/test_faceted_search.py index 58c936c0..b4fe2a8d 100644 --- a/test_opensearchpy/test_async/test_helpers/test_faceted_search.py +++ b/test_opensearchpy/test_async/test_helpers/test_faceted_search.py @@ -9,6 +9,7 @@ # GitHub history for details. from datetime import datetime +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator @@ -55,7 +56,7 @@ async def test_query_is_created_properly() -> None: } == s.to_dict() -async def test_query_is_created_properly_with_sort_tuple(): +async def test_query_is_created_properly_with_sort_tuple() -> None: bs = BlogSearch("python search", sort=("category", "-title")) s = bs.build_search() @@ -79,7 +80,7 @@ async def test_query_is_created_properly_with_sort_tuple(): } == s.to_dict() -async def test_filter_is_applied_to_search_but_not_relevant_facet(): +async def test_filter_is_applied_to_search_but_not_relevant_facet() -> None: bs = BlogSearch("python search", filters={"category": "opensearch"}) s = bs.build_search() @@ -102,7 +103,7 @@ async def test_filter_is_applied_to_search_but_not_relevant_facet(): } == s.to_dict() -async def test_filters_are_applied_to_search_ant_relevant_facets(): +async def test_filters_are_applied_to_search_ant_relevant_facets() -> None: bs = BlogSearch( "python search", filters={"category": "opensearch", "tags": ["python", "django"]}, @@ -168,8 +169,8 @@ async def test_date_histogram_facet_with_1970_01_01_date() -> None: ("interval", "1h"), ("fixed_interval", "1h"), ], -) -async def test_date_histogram_interval_types(interval_type, interval) -> None: +) # type: ignore +async def test_date_histogram_interval_types(interval_type: Any, interval: Any) -> None: dhf = DateHistogramFacet(field="@timestamp", **{interval_type: interval}) assert dhf.get_aggregation().to_dict() == { "date_histogram": { diff --git a/test_opensearchpy/test_async/test_helpers/test_index.py b/test_opensearchpy/test_async/test_helpers/test_index.py index 681b9cfe..e59d86ad 100644 --- a/test_opensearchpy/test_async/test_helpers/test_index.py +++ b/test_opensearchpy/test_async/test_helpers/test_index.py @@ -10,6 +10,7 @@ import string from random import choice +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator @@ -118,7 +119,7 @@ async def test_registered_doc_type_included_in_search() -> None: async def test_aliases_add_to_object() -> None: random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) - alias_dict = {random_alias: {}} + alias_dict: Any = {random_alias: {}} index = AsyncIndex("i", using="alias") index.aliases(**alias_dict) @@ -128,7 +129,7 @@ async def test_aliases_add_to_object() -> None: async def test_aliases_returned_from_to_dict() -> None: random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) - alias_dict = {random_alias: {}} + alias_dict: Any = {random_alias: {}} index = AsyncIndex("i", using="alias") index.aliases(**alias_dict) @@ -136,7 +137,7 @@ async def test_aliases_returned_from_to_dict() -> None: assert index._aliases == index.to_dict()["aliases"] == alias_dict -async def test_analyzers_added_to_object(): +async def test_analyzers_added_to_object() -> None: random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" @@ -152,7 +153,7 @@ async def test_analyzers_added_to_object(): } -async def test_analyzers_returned_from_to_dict(): +async def test_analyzers_returned_from_to_dict() -> None: random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" @@ -173,7 +174,7 @@ async def test_conflicting_analyzer_raises_error() -> None: i.analyzer("my_analyzer", tokenizer="keyword", filter=["lowercase", "stop"]) -async def test_index_template_can_have_order(): +async def test_index_template_can_have_order() -> None: i = AsyncIndex("i-*") it = i.as_template("i", order=2) diff --git a/test_opensearchpy/test_async/test_helpers/test_mapping.py b/test_opensearchpy/test_async/test_helpers/test_mapping.py index 6ae4c0b7..797c295f 100644 --- a/test_opensearchpy/test_async/test_helpers/test_mapping.py +++ b/test_opensearchpy/test_async/test_helpers/test_mapping.py @@ -24,7 +24,7 @@ async def test_mapping_can_has_fields() -> None: } == m.to_dict() -async def test_mapping_update_is_recursive(): +async def test_mapping_update_is_recursive() -> None: m1 = mapping.AsyncMapping() m1.field("title", "text") m1.field("author", "object") @@ -67,7 +67,7 @@ async def test_properties_can_iterate_over_all_the_fields() -> None: } -async def test_mapping_can_collect_all_analyzers_and_normalizers(): +async def test_mapping_can_collect_all_analyzers_and_normalizers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", @@ -140,7 +140,7 @@ async def test_mapping_can_collect_all_analyzers_and_normalizers(): assert json.loads(json.dumps(m.to_dict())) == m.to_dict() -async def test_mapping_can_collect_multiple_analyzers(): +async def test_mapping_can_collect_multiple_analyzers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", diff --git a/test_opensearchpy/test_async/test_helpers/test_search.py b/test_opensearchpy/test_async/test_helpers/test_search.py index c32a8c7c..1af617d7 100644 --- a/test_opensearchpy/test_async/test_helpers/test_search.py +++ b/test_opensearchpy/test_async/test_helpers/test_search.py @@ -9,6 +9,7 @@ # GitHub history for details. from copy import deepcopy +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator @@ -71,7 +72,7 @@ async def test_query_can_be_assigned_to() -> None: assert s.query._proxied is q -async def test_query_can_be_wrapped(): +async def test_query_can_be_wrapped() -> None: s = search.AsyncSearch().query("match", title="python") s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) @@ -142,7 +143,7 @@ async def test_aggs_allow_two_metric() -> None: } -async def test_aggs_get_copied_on_change(): +async def test_aggs_get_copied_on_change() -> None: s = search.AsyncSearch().query("match_all") s.aggs.bucket("per_tag", "terms", field="f").metric( "max_score", "max", field="score" @@ -155,7 +156,7 @@ async def test_aggs_get_copied_on_change(): s4 = s3._clone() s4.aggs.metric("max_score", "max", field="score") - d = { + d: Any = { "query": {"match_all": {}}, "aggs": { "per_tag": { @@ -218,7 +219,7 @@ class MyDocument(AsyncDocument): assert s._doc_type_map == {} -async def test_sort(): +async def test_sort() -> None: s = search.AsyncSearch() s = s.sort("fielda", "-fieldb") @@ -254,7 +255,7 @@ async def test_index() -> None: assert {"from": 3, "size": 1} == s[3].to_dict() -async def test_search_to_dict(): +async def test_search_to_dict() -> None: s = search.AsyncSearch() assert {} == s.to_dict() @@ -283,7 +284,7 @@ async def test_search_to_dict(): assert {"size": 5, "from": 42} == s.to_dict() -async def test_complex_example(): +async def test_complex_example() -> None: s = search.AsyncSearch() s = ( s.query("match", title="python") @@ -334,7 +335,7 @@ async def test_complex_example(): } == s.to_dict() -async def test_reverse(): +async def test_reverse() -> None: d = { "query": { "filtered": { @@ -406,7 +407,7 @@ async def test_source() -> None: ).source(["f1", "f2"]).to_dict() -async def test_source_on_clone(): +async def test_source_on_clone() -> None: assert { "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, @@ -431,7 +432,7 @@ async def test_source_on_clear() -> None: ) -async def test_suggest_accepts_global_text(): +async def test_suggest_accepts_global_text() -> None: s = search.AsyncSearch.from_dict( { "suggest": { @@ -453,7 +454,7 @@ async def test_suggest_accepts_global_text(): } == s.to_dict() -async def test_suggest(): +async def test_suggest() -> None: s = search.AsyncSearch() s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) @@ -475,7 +476,7 @@ async def test_exclude() -> None: } == s.to_dict() -async def test_update_from_dict(): +async def test_update_from_dict() -> None: s = search.AsyncSearch() s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) s.update_from_dict({"_source": ["id", "name"]}) @@ -486,7 +487,7 @@ async def test_update_from_dict(): } == s.to_dict() -async def test_rescore_query_to_dict(): +async def test_rescore_query_to_dict() -> None: s = search.AsyncSearch(index="index-name") positive_query = Q( diff --git a/test_opensearchpy/test_async/test_helpers/test_update_by_query.py b/test_opensearchpy/test_async/test_helpers/test_update_by_query.py index b15983dc..52fc20c3 100644 --- a/test_opensearchpy/test_async/test_helpers/test_update_by_query.py +++ b/test_opensearchpy/test_async/test_helpers/test_update_by_query.py @@ -26,7 +26,7 @@ async def test_ubq_starts_with_no_query() -> None: assert ubq.query._proxied is None -async def test_ubq_to_dict(): +async def test_ubq_to_dict() -> None: ubq = update_by_query.AsyncUpdateByQuery() assert {} == ubq.to_dict() @@ -44,7 +44,7 @@ async def test_ubq_to_dict(): assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict() -async def test_complex_example(): +async def test_complex_example() -> None: ubq = update_by_query.AsyncUpdateByQuery() ubq = ( ubq.query("match", title="python") @@ -95,7 +95,7 @@ async def test_exclude() -> None: } == ubq.to_dict() -async def test_reverse(): +async def test_reverse() -> None: d = { "query": { "filtered": { @@ -137,7 +137,7 @@ async def test_from_dict_doesnt_need_query() -> None: assert {"script": {"source": "test"}} == ubq.to_dict() -async def test_overwrite_script(): +async def test_overwrite_script() -> None: ubq = update_by_query.AsyncUpdateByQuery() ubq = ubq.script( source="ctx._source.likes += params.f", lang="painless", params={"f": 3} diff --git a/test_opensearchpy/test_async/test_http_connection.py b/test_opensearchpy/test_async/test_http_connection.py index 913a944d..febb231b 100644 --- a/test_opensearchpy/test_async/test_http_connection.py +++ b/test_opensearchpy/test_async/test_http_connection.py @@ -26,12 +26,14 @@ # under the License. +from typing import Any + import mock import pytest from _pytest.mark.structures import MarkDecorator from multidict import CIMultiDict -from opensearchpy._async._extra_imports import aiohttp +from opensearchpy._async._extra_imports import aiohttp # type: ignore from opensearchpy._async.compat import get_running_loop from opensearchpy.connection.http_async import AsyncHttpConnection @@ -52,15 +54,15 @@ def test_auth_as_string(self) -> None: assert c._http_auth.password, "password" def test_auth_as_callable(self) -> None: - def auth_fn(): + def auth_fn() -> None: pass c = AsyncHttpConnection(http_auth=auth_fn) assert callable(c._http_auth) @mock.patch("aiohttp.ClientSession.request", new_callable=mock.Mock) - async def test_basicauth_in_request_session(self, mock_request) -> None: - async def do_request(*args, **kwargs): + async def test_basicauth_in_request_session(self, mock_request: Any) -> None: + async def do_request(*args: Any, **kwargs: Any) -> Any: response_mock = mock.AsyncMock() response_mock.headers = CIMultiDict() response_mock.status = 200 @@ -90,13 +92,13 @@ async def do_request(*args, **kwargs): ) @mock.patch("aiohttp.ClientSession.request", new_callable=mock.Mock) - async def test_callable_in_request_session(self, mock_request) -> None: - def auth_fn(*args, **kwargs): + async def test_callable_in_request_session(self, mock_request: Any) -> None: + def auth_fn(*args: Any, **kwargs: Any) -> Any: return { "Test": "PASSED", } - async def do_request(*args, **kwargs): + async def do_request(*args: Any, **kwargs: Any) -> Any: response_mock = mock.AsyncMock() response_mock.headers = CIMultiDict() response_mock.status = 200 diff --git a/test_opensearchpy/test_async/test_plugins_client.py b/test_opensearchpy/test_async/test_plugins_client.py index 2364f0fa..d701892c 100644 --- a/test_opensearchpy/test_async/test_plugins_client.py +++ b/test_opensearchpy/test_async/test_plugins_client.py @@ -17,7 +17,8 @@ class TestPluginsClient(TestCase): async def test_plugins_client(self) -> None: with self.assertWarns(Warning) as w: client = AsyncOpenSearch() - client.plugins.__init__(client) # double-init + # testing double-init here + client.plugins.__init__(client) # type: ignore self.assertEqual( str(w.warnings[0].message), "Cannot load `alerting` directly to AsyncOpenSearch as it already exists. Use `AsyncOpenSearch.plugin.alerting` instead.", diff --git a/test_opensearchpy/test_async/test_server/conftest.py b/test_opensearchpy/test_async/test_server/conftest.py index 908313ee..87680b36 100644 --- a/test_opensearchpy/test_async/test_server/conftest.py +++ b/test_opensearchpy/test_async/test_server/conftest.py @@ -27,27 +27,28 @@ import asyncio +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator import opensearchpy -from opensearchpy.helpers.test import OPENSEARCH_URL +from opensearchpy.helpers.test import OPENSEARCH_URL # type: ignore from ...utils import wipe_cluster pytestmark: MarkDecorator = pytest.mark.asyncio -@pytest.fixture(scope="function") -async def async_client(): +@pytest.fixture(scope="function") # type: ignore +async def async_client() -> Any: client = None try: if not hasattr(opensearchpy, "AsyncOpenSearch"): pytest.skip("test requires 'AsyncOpenSearch'") kw = {"timeout": 3} - client = opensearchpy.AsyncOpenSearch(OPENSEARCH_URL, **kw) + client = opensearchpy.AsyncOpenSearch(OPENSEARCH_URL, **kw) # type: ignore # wait for yellow status for _ in range(100): diff --git a/test_opensearchpy/test_async/test_server/test_clients.py b/test_opensearchpy/test_async/test_server/test_clients.py index 41a07012..323532c5 100644 --- a/test_opensearchpy/test_async/test_server/test_clients.py +++ b/test_opensearchpy/test_async/test_server/test_clients.py @@ -28,6 +28,8 @@ from __future__ import unicode_literals +from typing import Any + import pytest from _pytest.mark.structures import MarkDecorator @@ -35,19 +37,19 @@ class TestUnicode: - async def test_indices_analyze(self, async_client) -> None: + async def test_indices_analyze(self, async_client: Any) -> None: await async_client.indices.analyze(body='{"text": "привет"}') class TestBulk: - async def test_bulk_works_with_string_body(self, async_client) -> None: + async def test_bulk_works_with_string_body(self, async_client: Any) -> None: docs = '{ "index" : { "_index" : "bulk_test_index", "_id" : "1" } }\n{"answer": 42}' response = await async_client.bulk(body=docs) assert response["errors"] is False assert len(response["items"]) == 1 - async def test_bulk_works_with_bytestring_body(self, async_client) -> None: + async def test_bulk_works_with_bytestring_body(self, async_client: Any) -> None: docs = b'{ "index" : { "_index" : "bulk_test_index", "_id" : "2" } }\n{"answer": 42}' response = await async_client.bulk(body=docs) @@ -57,7 +59,7 @@ async def test_bulk_works_with_bytestring_body(self, async_client) -> None: class TestYarlMissing: async def test_aiohttp_connection_works_without_yarl( - self, async_client, monkeypatch + self, async_client: Any, monkeypatch: Any ) -> None: # This is a defensive test case for if aiohttp suddenly stops using yarl. from opensearchpy._async import http_aiohttp diff --git a/test_opensearchpy/test_async/test_server/test_helpers/conftest.py b/test_opensearchpy/test_async/test_server/test_helpers/conftest.py index 36ea7a10..69282ead 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/conftest.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/conftest.py @@ -10,6 +10,7 @@ import re from datetime import datetime +from typing import Any import pytest from pytest import fixture @@ -34,32 +35,32 @@ pytestmark = pytest.mark.asyncio -@fixture(scope="function") -async def client(): +@fixture(scope="function") # type: ignore +async def client() -> Any: client = await get_test_client(verify_certs=False, http_auth=("admin", "admin")) await add_connection("default", client) return client -@fixture(scope="function") -async def opensearch_version(client): +@fixture(scope="function") # type: ignore +async def opensearch_version(client: Any) -> Any: info = await client.info() print(info) yield tuple( int(x) - for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") + for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore ) -@fixture -async def write_client(client): +@fixture # type: ignore +async def write_client(client: Any) -> Any: yield client await client.indices.delete("test-*", ignore=404) await client.indices.delete_template("test-template", ignore=404) -@fixture -async def data_client(client): +@fixture # type: ignore +async def data_client(client: Any) -> Any: # create mappings await create_git_index(client, "git") await create_flat_git_index(client, "flat-git") @@ -71,8 +72,8 @@ async def data_client(client): await client.indices.delete("flat-git", ignore=404) -@fixture -async def pull_request(write_client): +@fixture # type: ignore +async def pull_request(write_client: Any) -> Any: await PullRequest.init() pr = PullRequest( _id=42, @@ -95,8 +96,8 @@ async def pull_request(write_client): return pr -@fixture -async def setup_ubq_tests(client) -> str: +@fixture # type: ignore +async def setup_ubq_tests(client: Any) -> str: index = "test-git" await create_git_index(client, index) await async_bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True) diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_actions.py b/test_opensearchpy/test_async/test_server/test_helpers/test_actions.py index dee69819..b0c5375a 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_actions.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_actions.py @@ -27,7 +27,7 @@ import asyncio -from typing import Tuple +from typing import Any, List import pytest from mock import MagicMock, patch @@ -40,19 +40,19 @@ class AsyncMock(MagicMock): - async def __call__(self, *args, **kwargs): + async def __call__(self, *args: Any, **kwargs: Any) -> Any: return super(AsyncMock, self).__call__(*args, **kwargs) - def __await__(self): + def __await__(self) -> Any: return self().__await__() class FailingBulkClient(object): def __init__( self, - client, - fail_at: Tuple[int] = (2,), - fail_with=TransportError(599, "Error!", {}), + client: Any, + fail_at: Any = (2,), + fail_with: TransportError = TransportError(599, "Error!", {}), ) -> None: self.client = client self._called = 0 @@ -60,7 +60,7 @@ def __init__( self.transport = client.transport self._fail_with = fail_with - async def bulk(self, *args, **kwargs): + async def bulk(self, *args: Any, **kwargs: Any) -> Any: self._called += 1 if self._called in self._fail_at: raise self._fail_with @@ -68,7 +68,7 @@ async def bulk(self, *args, **kwargs): class TestStreamingBulk(object): - async def test_actions_remain_unchanged(self, async_client) -> None: + async def test_actions_remain_unchanged(self, async_client: Any) -> None: actions1 = [{"_id": 1}, {"_id": 2}] async for ok, item in actions.async_streaming_bulk( async_client, actions1, index="test-index" @@ -76,7 +76,7 @@ async def test_actions_remain_unchanged(self, async_client) -> None: assert ok assert [{"_id": 1}, {"_id": 2}] == actions1 - async def test_all_documents_get_inserted(self, async_client) -> None: + async def test_all_documents_get_inserted(self, async_client: Any) -> None: docs = [{"answer": x, "_id": x} for x in range(100)] async for ok, item in actions.async_streaming_bulk( async_client, docs, index="test-index", refresh=True @@ -88,13 +88,13 @@ async def test_all_documents_get_inserted(self, async_client) -> None: "_source" ] - async def test_documents_data_types(self, async_client): - async def async_gen(): + async def test_documents_data_types(self, async_client: Any) -> None: + async def async_gen() -> Any: for x in range(100): await asyncio.sleep(0) yield {"answer": x, "_id": x} - def sync_gen(): + def sync_gen() -> Any: for x in range(100): yield {"answer": x, "_id": x} @@ -123,7 +123,7 @@ def sync_gen(): ] async def test_all_errors_from_chunk_are_raised_on_failure( - self, async_client + self, async_client: Any ) -> None: await async_client.indices.create( "i", @@ -144,7 +144,7 @@ async def test_all_errors_from_chunk_are_raised_on_failure( else: assert False, "exception should have been raised" - async def test_different_op_types(self, async_client): + async def test_different_op_types(self, async_client: Any) -> None: await async_client.index(index="i", id=45, body={}) await async_client.index(index="i", id=42, body={}) docs = [ @@ -159,7 +159,7 @@ async def test_different_op_types(self, async_client): assert {"answer": 42} == (await async_client.get(index="i", id=42))["_source"] assert {"f": "v"} == (await async_client.get(index="i", id=47))["_source"] - async def test_transport_error_can_becaught(self, async_client): + async def test_transport_error_can_becaught(self, async_client: Any) -> None: failing_client = FailingBulkClient(async_client) docs = [ {"_index": "i", "_id": 47, "f": "v"}, @@ -193,7 +193,7 @@ async def test_transport_error_can_becaught(self, async_client): } } == results[1][1] - async def test_rejected_documents_are_retried(self, async_client) -> None: + async def test_rejected_documents_are_retried(self, async_client: Any) -> None: failing_client = FailingBulkClient( async_client, fail_with=TransportError(429, "Rejected!", {}) ) @@ -222,7 +222,7 @@ async def test_rejected_documents_are_retried(self, async_client) -> None: assert 4 == failing_client._called async def test_rejected_documents_are_retried_at_most_max_retries_times( - self, async_client + self, async_client: Any ) -> None: failing_client = FailingBulkClient( async_client, fail_at=(1, 2), fail_with=TransportError(429, "Rejected!", {}) @@ -253,7 +253,7 @@ async def test_rejected_documents_are_retried_at_most_max_retries_times( assert 4 == failing_client._called async def test_transport_error_is_raised_with_max_retries( - self, async_client + self, async_client: Any ) -> None: failing_client = FailingBulkClient( async_client, @@ -261,7 +261,7 @@ async def test_transport_error_is_raised_with_max_retries( fail_with=TransportError(429, "Rejected!", {}), ) - async def streaming_bulk(): + async def streaming_bulk() -> Any: results = [ x async for x in actions.async_streaming_bulk( @@ -280,7 +280,7 @@ async def streaming_bulk(): class TestBulk(object): - async def test_bulk_works_with_single_item(self, async_client) -> None: + async def test_bulk_works_with_single_item(self, async_client: Any) -> None: docs = [{"answer": 42, "_id": 1}] success, failed = await actions.async_bulk( async_client, docs, index="test-index", refresh=True @@ -293,7 +293,7 @@ async def test_bulk_works_with_single_item(self, async_client) -> None: "_source" ] - async def test_all_documents_get_inserted(self, async_client) -> None: + async def test_all_documents_get_inserted(self, async_client: Any) -> None: docs = [{"answer": x, "_id": x} for x in range(100)] success, failed = await actions.async_bulk( async_client, docs, index="test-index", refresh=True @@ -306,7 +306,7 @@ async def test_all_documents_get_inserted(self, async_client) -> None: "_source" ] - async def test_stats_only_reports_numbers(self, async_client) -> None: + async def test_stats_only_reports_numbers(self, async_client: Any) -> None: docs = [{"answer": x} for x in range(100)] success, failed = await actions.async_bulk( async_client, docs, index="test-index", refresh=True, stats_only=True @@ -316,7 +316,7 @@ async def test_stats_only_reports_numbers(self, async_client) -> None: assert 0 == failed assert 100 == (await async_client.count(index="test-index"))["count"] - async def test_errors_are_reported_correctly(self, async_client): + async def test_errors_are_reported_correctly(self, async_client: Any) -> None: await async_client.indices.create( "i", { @@ -333,6 +333,7 @@ async def test_errors_are_reported_correctly(self, async_client): raise_on_error=False, ) assert 1 == success + assert isinstance(failed, List) assert 1 == len(failed) error = failed[0] assert "42" == error["index"]["_id"] @@ -342,7 +343,7 @@ async def test_errors_are_reported_correctly(self, async_client): error["index"]["error"] ) or "mapper_parsing_exception" in repr(error["index"]["error"]) - async def test_error_is_raised(self, async_client): + async def test_error_is_raised(self, async_client: Any) -> None: await async_client.indices.create( "i", { @@ -355,7 +356,7 @@ async def test_error_is_raised(self, async_client): with pytest.raises(BulkIndexError): await actions.async_bulk(async_client, [{"a": 42}, {"a": "c"}], index="i") - async def test_ignore_error_if_raised(self, async_client): + async def test_ignore_error_if_raised(self, async_client: Any) -> None: # ignore the status code 400 in tuple await actions.async_bulk( async_client, [{"a": 42}, {"a": "c"}], index="i", ignore_status=(400,) @@ -388,7 +389,7 @@ async def test_ignore_error_if_raised(self, async_client): failing_client, [{"a": 42}], index="i", ignore_status=(599,) ) - async def test_errors_are_collected_properly(self, async_client): + async def test_errors_are_collected_properly(self, async_client: Any) -> None: await async_client.indices.create( "i", { @@ -410,10 +411,12 @@ async def test_errors_are_collected_properly(self, async_client): class MockScroll: + calls: Any + def __init__(self) -> None: self.calls = [] - async def __call__(self, *args, **kwargs): + async def __call__(self, *args: Any, **kwargs: Any) -> Any: self.calls.append((args, kwargs)) if len(self.calls) == 1: return { @@ -432,25 +435,27 @@ async def __call__(self, *args, **kwargs): class MockResponse: - def __init__(self, resp) -> None: + def __init__(self, resp: Any) -> None: self.resp = resp - async def __call__(self, *args, **kwargs): + async def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.resp - def __await__(self): + def __await__(self) -> Any: return self().__await__() -@pytest.fixture(scope="function") -async def scan_teardown(async_client): +@pytest.fixture(scope="function") # type: ignore +async def scan_teardown(async_client: Any) -> Any: yield await async_client.clear_scroll(scroll_id="_all") class TestScan(object): - async def test_order_can_be_preserved(self, async_client, scan_teardown): - bulk = [] + async def test_order_can_be_preserved( + self, async_client: Any, scan_teardown: Any + ) -> None: + bulk: Any = [] for x in range(100): bulk.append({"index": {"_index": "test_index", "_id": x}}) bulk.append({"answer": x, "correct": x == 42}) @@ -470,8 +475,10 @@ async def test_order_can_be_preserved(self, async_client, scan_teardown): assert list(map(str, range(100))) == list(d["_id"] for d in docs) assert list(range(100)) == list(d["_source"]["answer"] for d in docs) - async def test_all_documents_are_read(self, async_client, scan_teardown): - bulk = [] + async def test_all_documents_are_read( + self, async_client: Any, scan_teardown: Any + ) -> None: + bulk: Any = [] for x in range(100): bulk.append({"index": {"_index": "test_index", "_id": x}}) bulk.append({"answer": x, "correct": x == 42}) @@ -486,8 +493,8 @@ async def test_all_documents_are_read(self, async_client, scan_teardown): assert set(map(str, range(100))) == set(d["_id"] for d in docs) assert set(range(100)) == set(d["_source"]["answer"] for d in docs) - async def test_scroll_error(self, async_client, scan_teardown): - bulk = [] + async def test_scroll_error(self, async_client: Any, scan_teardown: Any) -> None: + bulk: Any = [] for x in range(4): bulk.append({"index": {"_index": "test_index"}}) bulk.append({"value": x}) @@ -522,7 +529,9 @@ async def test_scroll_error(self, async_client, scan_teardown): assert len(data) == 3 assert data[-1] == {"scroll_data": 42} - async def test_initial_search_error(self, async_client, scan_teardown): + async def test_initial_search_error( + self, async_client: Any, scan_teardown: Any + ) -> None: with patch.object(async_client, "clear_scroll", new_callable=AsyncMock): with patch.object( async_client, @@ -572,7 +581,9 @@ async def test_initial_search_error(self, async_client, scan_teardown): assert data == [{"search_data": 1}] assert mock_scroll.calls == [] - async def test_no_scroll_id_fast_route(self, async_client, scan_teardown) -> None: + async def test_no_scroll_id_fast_route( + self, async_client: Any, scan_teardown: Any + ) -> None: with patch.object(async_client, "search", MockResponse({"no": "_scroll_id"})): with patch.object(async_client, "scroll") as scroll_mock: with patch.object(async_client, "clear_scroll") as clear_mock: @@ -588,8 +599,10 @@ async def test_no_scroll_id_fast_route(self, async_client, scan_teardown) -> Non clear_mock.assert_not_called() @patch("opensearchpy._async.helpers.actions.logger") - async def test_logger(self, logger_mock, async_client, scan_teardown): - bulk = [] + async def test_logger( + self, logger_mock: Any, async_client: Any, scan_teardown: Any + ) -> None: + bulk: Any = [] for x in range(4): bulk.append({"index": {"_index": "test_index"}}) bulk.append({"value": x}) @@ -629,8 +642,8 @@ async def test_logger(self, logger_mock, async_client, scan_teardown): 5, ) - async def test_clear_scroll(self, async_client, scan_teardown): - bulk = [] + async def test_clear_scroll(self, async_client: Any, scan_teardown: Any) -> None: + bulk: Any = [] for x in range(4): bulk.append({"index": {"_index": "test_index"}}) bulk.append({"value": x}) @@ -672,10 +685,10 @@ async def test_clear_scroll(self, async_client, scan_teardown): {"http_auth": ("username", "password")}, {"headers": {"custom", "header"}}, ], - ) + ) # type: ignore async def test_scan_auth_kwargs_forwarded( - self, async_client, scan_teardown, kwargs - ): + self, async_client: Any, scan_teardown: Any, kwargs: Any + ) -> None: ((key, val),) = kwargs.items() with patch.object( @@ -716,8 +729,8 @@ async def test_scan_auth_kwargs_forwarded( assert api_mock.call_args[1][key] == val async def test_scan_auth_kwargs_favor_scroll_kwargs_option( - self, async_client, scan_teardown - ): + self, async_client: Any, scan_teardown: Any + ) -> None: with patch.object( async_client, "search", @@ -765,9 +778,9 @@ async def test_scan_auth_kwargs_favor_scroll_kwargs_option( assert async_client.scroll.call_args[1]["sort"] == "asc" -@pytest.fixture(scope="function") -async def reindex_setup(async_client): - bulk = [] +@pytest.fixture(scope="function") # type: ignore +async def reindex_setup(async_client: Any) -> Any: + bulk: Any = [] for x in range(100): bulk.append({"index": {"_index": "test_index", "_id": x}}) bulk.append( @@ -783,7 +796,7 @@ async def reindex_setup(async_client): class TestReindex(object): async def test_reindex_passes_kwargs_to_scan_and_bulk( - self, async_client, reindex_setup + self, async_client: Any, reindex_setup: Any ) -> None: await actions.async_reindex( async_client, @@ -803,7 +816,9 @@ async def test_reindex_passes_kwargs_to_scan_and_bulk( await async_client.get(index="prod_index", id=42) )["_source"] - async def test_reindex_accepts_a_query(self, async_client, reindex_setup) -> None: + async def test_reindex_accepts_a_query( + self, async_client: Any, reindex_setup: Any + ) -> None: await actions.async_reindex( async_client, "test_index", @@ -822,7 +837,9 @@ async def test_reindex_accepts_a_query(self, async_client, reindex_setup) -> Non await async_client.get(index="prod_index", id=42) )["_source"] - async def test_all_documents_get_moved(self, async_client, reindex_setup) -> None: + async def test_all_documents_get_moved( + self, async_client: Any, reindex_setup: Any + ) -> None: await actions.async_reindex(async_client, "test_index", "prod_index") await async_client.indices.refresh() @@ -843,8 +860,8 @@ async def test_all_documents_get_moved(self, async_client, reindex_setup) -> Non )["_source"] -@pytest.fixture(scope="function") -async def parent_reindex_setup(async_client): +@pytest.fixture(scope="function") # type: ignore +async def parent_reindex_setup(async_client: Any) -> None: body = { "settings": {"number_of_shards": 1, "number_of_replicas": 0}, "mappings": { @@ -873,8 +890,8 @@ async def parent_reindex_setup(async_client): class TestParentChildReindex: async def test_children_are_reindexed_correctly( - self, async_client, parent_reindex_setup - ): + self, async_client: Any, parent_reindex_setup: Any + ) -> None: await actions.async_reindex(async_client, "test-index", "real-index") assert {"question_answer": "question"} == ( await async_client.get(index="real-index", id=42) diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_data.py b/test_opensearchpy/test_async/test_server/test_helpers/test_data.py index 99f2486d..7a23b8b1 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_data.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_data.py @@ -13,7 +13,7 @@ from typing import Any, Dict -async def create_flat_git_index(client, index): +async def create_flat_git_index(client: Any, index: Any) -> None: # we will use user on several places user_mapping = { "properties": {"name": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} @@ -56,7 +56,7 @@ async def create_flat_git_index(client, index): ) -async def create_git_index(client, index): +async def create_git_index(client: Any, index: Any) -> None: # we will use user on several places user_mapping = { "properties": {"name": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} @@ -1078,7 +1078,7 @@ async def create_git_index(client, index): ] -def flatten_doc(d) -> Dict[str, Any]: +def flatten_doc(d: Any) -> Dict[str, Any]: src = d["_source"].copy() del src["commit_repo"] return {"_index": "flat-git", "_id": d["_id"], "_source": src} @@ -1087,7 +1087,7 @@ def flatten_doc(d) -> Dict[str, Any]: FLAT_DATA = [flatten_doc(d) for d in DATA if "routing" in d] -def create_test_git_data(d) -> Dict[str, Any]: +def create_test_git_data(d: Any) -> Dict[str, Any]: src = d["_source"].copy() return { "_index": "test-git", diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_document.py b/test_opensearchpy/test_async/test_server/test_helpers/test_document.py index 67982918..8e4e95e2 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_document.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_document.py @@ -10,6 +10,7 @@ from datetime import datetime from ipaddress import ip_address +from typing import Any, Optional import pytest from pytest import raises @@ -63,7 +64,7 @@ class Repository(AsyncDocument): tags = Keyword() @classmethod - def search(cls): + def search(cls, using: Any = None, index: Optional[str] = None) -> Any: return super(Repository, cls).search().filter("term", commit_repo="repo") class Index: @@ -116,7 +117,7 @@ class Index: name = "test-serialization" -async def test_serialization(write_client): +async def test_serialization(write_client: Any) -> None: await SerializationDoc.init() await write_client.index( index="test-serialization", @@ -129,7 +130,7 @@ async def test_serialization(write_client): "ip": ["::1", "127.0.0.1", None], }, ) - sd = await SerializationDoc.get(id=42) + sd: Any = await SerializationDoc.get(id=42) assert sd.i == [1, 2, 3, None] assert sd.b == [True, False, True, False, None] @@ -146,7 +147,7 @@ async def test_serialization(write_client): } -async def test_nested_inner_hits_are_wrapped_properly(pull_request) -> None: +async def test_nested_inner_hits_are_wrapped_properly(pull_request: Any) -> None: history_query = Q( "nested", path="comments.history", @@ -174,7 +175,7 @@ async def test_nested_inner_hits_are_wrapped_properly(pull_request) -> None: assert "score" in history.meta -async def test_nested_inner_hits_are_deserialized_properly(pull_request) -> None: +async def test_nested_inner_hits_are_deserialized_properly(pull_request: Any) -> None: s = PullRequest.search().query( "nested", inner_hits={}, @@ -189,7 +190,7 @@ async def test_nested_inner_hits_are_deserialized_properly(pull_request) -> None assert isinstance(pr.comments[0].created_at, datetime) -async def test_nested_top_hits_are_wrapped_properly(pull_request) -> None: +async def test_nested_top_hits_are_wrapped_properly(pull_request: Any) -> None: s = PullRequest.search() s.aggs.bucket("comments", "nested", path="comments").metric( "hits", "top_hits", size=1 @@ -201,7 +202,7 @@ async def test_nested_top_hits_are_wrapped_properly(pull_request) -> None: assert isinstance(r.aggregations.comments.hits.hits[0], Comment) -async def test_update_object_field(write_client) -> None: +async def test_update_object_field(write_client: Any) -> None: await Wiki.init() w = Wiki( owner=User(name="Honza Kral"), @@ -221,7 +222,7 @@ async def test_update_object_field(write_client) -> None: assert w.ranked == {"test1": 0.1, "topic2": 0.2} -async def test_update_script(write_client) -> None: +async def test_update_script(write_client: Any) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) await w.save() @@ -231,7 +232,7 @@ async def test_update_script(write_client) -> None: assert w.views == 47 -async def test_update_retry_on_conflict(write_client) -> None: +async def test_update_retry_on_conflict(write_client: Any) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) await w.save() @@ -249,8 +250,10 @@ async def test_update_retry_on_conflict(write_client) -> None: assert w.views == 52 -@pytest.mark.parametrize("retry_on_conflict", [None, 0]) -async def test_update_conflicting_version(write_client, retry_on_conflict) -> None: +@pytest.mark.parametrize("retry_on_conflict", [None, 0]) # type: ignore +async def test_update_conflicting_version( + write_client: Any, retry_on_conflict: bool +) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) await w.save() @@ -267,7 +270,7 @@ async def test_update_conflicting_version(write_client, retry_on_conflict) -> No ) -async def test_save_and_update_return_doc_meta(write_client) -> None: +async def test_save_and_update_return_doc_meta(write_client: Any) -> None: await Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) resp = await w.save(return_doc_meta=True) @@ -291,33 +294,33 @@ async def test_save_and_update_return_doc_meta(write_client) -> None: assert resp.keys().__contains__("_version") -async def test_init(write_client) -> None: +async def test_init(write_client: Any) -> None: await Repository.init(index="test-git") assert await write_client.indices.exists(index="test-git") -async def test_get_raises_404_on_index_missing(data_client) -> None: +async def test_get_raises_404_on_index_missing(data_client: Any) -> None: with raises(NotFoundError): await Repository.get("opensearch-dsl-php", index="not-there") -async def test_get_raises_404_on_non_existent_id(data_client) -> None: +async def test_get_raises_404_on_non_existent_id(data_client: Any) -> None: with raises(NotFoundError): await Repository.get("opensearch-dsl-php") -async def test_get_returns_none_if_404_ignored(data_client) -> None: +async def test_get_returns_none_if_404_ignored(data_client: Any) -> None: assert None is await Repository.get("opensearch-dsl-php", ignore=404) async def test_get_returns_none_if_404_ignored_and_index_doesnt_exist( - data_client, + data_client: Any, ) -> None: assert None is await Repository.get("42", index="not-there", ignore=404) -async def test_get(data_client) -> None: +async def test_get(data_client: Any) -> None: opensearch_repo = await Repository.get("opensearch-py") assert isinstance(opensearch_repo, Repository) @@ -325,15 +328,15 @@ async def test_get(data_client) -> None: assert datetime(2014, 3, 3) == opensearch_repo.created_at -async def test_exists_return_true(data_client) -> None: +async def test_exists_return_true(data_client: Any) -> None: assert await Repository.exists("opensearch-py") -async def test_exists_false(data_client) -> None: +async def test_exists_false(data_client: Any) -> None: assert not await Repository.exists("opensearch-dsl-php") -async def test_get_with_tz_date(data_client) -> None: +async def test_get_with_tz_date(data_client: Any) -> None: first_commit = await Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="opensearch-py" ) @@ -345,7 +348,7 @@ async def test_get_with_tz_date(data_client) -> None: ) -async def test_save_with_tz_date(data_client) -> None: +async def test_save_with_tz_date(data_client: Any) -> None: tzinfo = timezone("Europe/Prague") first_commit = await Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="opensearch-py" @@ -372,7 +375,7 @@ async def test_save_with_tz_date(data_client) -> None: ] -async def test_mget(data_client) -> None: +async def test_mget(data_client: Any) -> None: commits = await Commit.mget(COMMIT_DOCS_WITH_MISSING) assert commits[0] is None assert commits[1].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" @@ -380,25 +383,27 @@ async def test_mget(data_client) -> None: assert commits[3].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" -async def test_mget_raises_exception_when_missing_param_is_invalid(data_client) -> None: +async def test_mget_raises_exception_when_missing_param_is_invalid( + data_client: Any, +) -> None: with raises(ValueError): await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raj") -async def test_mget_raises_404_when_missing_param_is_raise(data_client) -> None: +async def test_mget_raises_404_when_missing_param_is_raise(data_client: Any) -> None: with raises(NotFoundError): await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raise") async def test_mget_ignores_missing_docs_when_missing_param_is_skip( - data_client, + data_client: Any, ) -> None: commits = await Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="skip") assert commits[0].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" assert commits[1].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" -async def test_update_works_from_search_response(data_client) -> None: +async def test_update_works_from_search_response(data_client: Any) -> None: opensearch_repo = (await Repository.search().execute())[0] await opensearch_repo.update(owner={"other_name": "opensearchpy"}) @@ -409,7 +414,7 @@ async def test_update_works_from_search_response(data_client) -> None: assert "opensearch" == new_version.owner.name -async def test_update(data_client) -> None: +async def test_update(data_client: Any) -> None: opensearch_repo = await Repository.get("opensearch-py") v = opensearch_repo.meta.version @@ -433,7 +438,7 @@ async def test_update(data_client) -> None: assert "primary_term" in new_version.meta -async def test_save_updates_existing_doc(data_client) -> None: +async def test_save_updates_existing_doc(data_client: Any) -> None: opensearch_repo = await Repository.get("opensearch-py") opensearch_repo.new_field = "testing-save" @@ -446,7 +451,9 @@ async def test_save_updates_existing_doc(data_client) -> None: assert new_repo["_seq_no"] == opensearch_repo.meta.seq_no -async def test_save_automatically_uses_seq_no_and_primary_term(data_client) -> None: +async def test_save_automatically_uses_seq_no_and_primary_term( + data_client: Any, +) -> None: opensearch_repo = await Repository.get("opensearch-py") opensearch_repo.meta.seq_no += 1 @@ -454,7 +461,9 @@ async def test_save_automatically_uses_seq_no_and_primary_term(data_client) -> N await opensearch_repo.save() -async def test_delete_automatically_uses_seq_no_and_primary_term(data_client) -> None: +async def test_delete_automatically_uses_seq_no_and_primary_term( + data_client: Any, +) -> None: opensearch_repo = await Repository.get("opensearch-py") opensearch_repo.meta.seq_no += 1 @@ -462,13 +471,13 @@ async def test_delete_automatically_uses_seq_no_and_primary_term(data_client) -> await opensearch_repo.delete() -async def assert_doc_equals(expected, actual) -> None: +async def assert_doc_equals(expected: Any, actual: Any) -> None: async for f in aiter(expected): assert f in actual assert actual[f] == expected[f] -async def test_can_save_to_different_index(write_client): +async def test_can_save_to_different_index(write_client: Any) -> None: test_repo = Repository(description="testing", meta={"id": 42}) assert await test_repo.save(index="test-document") @@ -483,7 +492,9 @@ async def test_can_save_to_different_index(write_client): ) -async def test_save_without_skip_empty_will_include_empty_fields(write_client) -> None: +async def test_save_without_skip_empty_will_include_empty_fields( + write_client: Any, +) -> None: test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) assert await test_repo.save(index="test-document", skip_empty=False) @@ -498,7 +509,7 @@ async def test_save_without_skip_empty_will_include_empty_fields(write_client) - ) -async def test_delete(write_client) -> None: +async def test_delete(write_client: Any) -> None: await write_client.create( index="test-document", id="opensearch-py", @@ -519,11 +530,11 @@ async def test_delete(write_client) -> None: ) -async def test_search(data_client) -> None: +async def test_search(data_client: Any) -> None: assert await Repository.search().count() == 1 -async def test_search_returns_proper_doc_classes(data_client) -> None: +async def test_search_returns_proper_doc_classes(data_client: Any) -> None: result = await Repository.search().execute() opensearch_repo = result.hits[0] @@ -532,8 +543,10 @@ async def test_search_returns_proper_doc_classes(data_client) -> None: assert opensearch_repo.owner.name == "opensearch" -async def test_refresh_mapping(data_client) -> None: +async def test_refresh_mapping(data_client: Any) -> None: class Commit(AsyncDocument): + _index: Any + class Index: name = "git" @@ -546,7 +559,7 @@ class Index: assert isinstance(Commit._index._mapping["committed_date"], Date) -async def test_highlight_in_meta(data_client) -> None: +async def test_highlight_in_meta(data_client: Any) -> None: commit = ( await Commit.search() .query("match", description="inverting") diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_faceted_search.py b/test_opensearchpy/test_async/test_server/test_helpers/test_faceted_search.py index bc7abbd8..b03fefe8 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_faceted_search.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_faceted_search.py @@ -9,6 +9,7 @@ # GitHub history for details. from datetime import datetime +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator @@ -54,8 +55,8 @@ class MetricSearch(AsyncFacetedSearch): } -@pytest.fixture(scope="function") -def commit_search_cls(opensearch_version): +@pytest.fixture(scope="function") # type: ignore +def commit_search_cls(opensearch_version: Any) -> Any: interval_kwargs = {"fixed_interval": "1d"} class CommitSearch(AsyncFacetedSearch): @@ -79,8 +80,8 @@ class CommitSearch(AsyncFacetedSearch): return CommitSearch -@pytest.fixture(scope="function") -def repo_search_cls(opensearch_version): +@pytest.fixture(scope="function") # type: ignore +def repo_search_cls(opensearch_version: Any) -> Any: interval_type = "calendar_interval" class RepoSearch(AsyncFacetedSearch): @@ -93,15 +94,15 @@ class RepoSearch(AsyncFacetedSearch): ), } - def search(self): + def search(self) -> Any: s = super(RepoSearch, self).search() return s.filter("term", commit_repo="repo") return RepoSearch -@pytest.fixture(scope="function") -def pr_search_cls(opensearch_version): +@pytest.fixture(scope="function") # type: ignore +def pr_search_cls(opensearch_version: Any) -> Any: interval_type = "calendar_interval" class PRSearch(AsyncFacetedSearch): @@ -119,7 +120,7 @@ class PRSearch(AsyncFacetedSearch): return PRSearch -async def test_facet_with_custom_metric(data_client) -> None: +async def test_facet_with_custom_metric(data_client: Any) -> None: ms = MetricSearch() r = await ms.execute() @@ -128,7 +129,7 @@ async def test_facet_with_custom_metric(data_client) -> None: assert dates[0] == 1399038439000 -async def test_nested_facet(pull_request, pr_search_cls) -> None: +async def test_nested_facet(pull_request: Any, pr_search_cls: Any) -> None: prs = pr_search_cls() r = await prs.execute() @@ -136,7 +137,7 @@ async def test_nested_facet(pull_request, pr_search_cls) -> None: assert [(datetime(2018, 1, 1, 0, 0), 1, False)] == r.facets.comments -async def test_nested_facet_with_filter(pull_request, pr_search_cls) -> None: +async def test_nested_facet_with_filter(pull_request: Any, pr_search_cls: Any) -> None: prs = pr_search_cls(filters={"comments": datetime(2018, 1, 1, 0, 0)}) r = await prs.execute() @@ -148,7 +149,7 @@ async def test_nested_facet_with_filter(pull_request, pr_search_cls) -> None: assert not r.hits -async def test_datehistogram_facet(data_client, repo_search_cls) -> None: +async def test_datehistogram_facet(data_client: Any, repo_search_cls: Any) -> None: rs = repo_search_cls() r = await rs.execute() @@ -156,7 +157,7 @@ async def test_datehistogram_facet(data_client, repo_search_cls) -> None: assert [(datetime(2014, 3, 1, 0, 0), 1, False)] == r.facets.created -async def test_boolean_facet(data_client, repo_search_cls) -> None: +async def test_boolean_facet(data_client: Any, repo_search_cls: Any) -> None: rs = repo_search_cls() r = await rs.execute() @@ -167,7 +168,7 @@ async def test_boolean_facet(data_client, repo_search_cls) -> None: async def test_empty_search_finds_everything( - data_client, opensearch_version, commit_search_cls + data_client: Any, opensearch_version: Any, commit_search_cls: Any ) -> None: cs = commit_search_cls() r = await cs.execute() @@ -213,7 +214,7 @@ async def test_empty_search_finds_everything( async def test_term_filters_are_shown_as_selected_and_data_is_filtered( - data_client, commit_search_cls + data_client: Any, commit_search_cls: Any ) -> None: cs = commit_search_cls(filters={"files": "test_opensearchpy/test_dsl"}) @@ -259,7 +260,7 @@ async def test_term_filters_are_shown_as_selected_and_data_is_filtered( async def test_range_filters_are_shown_as_selected_and_data_is_filtered( - data_client, commit_search_cls + data_client: Any, commit_search_cls: Any ) -> None: cs = commit_search_cls(filters={"deletions": "better"}) @@ -268,7 +269,7 @@ async def test_range_filters_are_shown_as_selected_and_data_is_filtered( assert 19 == r.hits.total.value -async def test_pagination(data_client, commit_search_cls) -> None: +async def test_pagination(data_client: Any, commit_search_cls: Any) -> None: cs = commit_search_cls() cs = cs[0:20] diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_index.py b/test_opensearchpy/test_async/test_server/test_helpers/test_index.py index f11e6d3f..14b87e15 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_index.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_index.py @@ -8,6 +8,8 @@ # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. +from typing import Any + import pytest from _pytest.mark.structures import MarkDecorator @@ -24,7 +26,7 @@ class Post(AsyncDocument): published_from = Date() -async def test_index_template_works(write_client) -> None: +async def test_index_template_works(write_client: Any) -> None: it = AsyncIndexTemplate("test-template", "test-*") it.document(Post) it.settings(number_of_replicas=0, number_of_shards=1) @@ -45,7 +47,7 @@ async def test_index_template_works(write_client) -> None: } == await write_client.indices.get_mapping(index="test-blog") -async def test_index_can_be_saved_even_with_settings(write_client) -> None: +async def test_index_can_be_saved_even_with_settings(write_client: Any) -> None: i = AsyncIndex("test-blog", using=write_client) i.settings(number_of_shards=3, number_of_replicas=0) await i.save() @@ -60,12 +62,14 @@ async def test_index_can_be_saved_even_with_settings(write_client) -> None: ) -async def test_index_exists(data_client) -> None: +async def test_index_exists(data_client: Any) -> None: assert await AsyncIndex("git").exists() assert not await AsyncIndex("not-there").exists() -async def test_index_can_be_created_with_settings_and_mappings(write_client) -> None: +async def test_index_can_be_created_with_settings_and_mappings( + write_client: Any, +) -> None: i = AsyncIndex("test-blog", using=write_client) i.document(Post) i.settings(number_of_replicas=0, number_of_shards=1) @@ -90,7 +94,7 @@ async def test_index_can_be_created_with_settings_and_mappings(write_client) -> } -async def test_delete(write_client) -> None: +async def test_delete(write_client: Any) -> None: await write_client.indices.create( index="test-index", body={"settings": {"number_of_replicas": 0, "number_of_shards": 1}}, @@ -101,9 +105,9 @@ async def test_delete(write_client) -> None: assert not await write_client.indices.exists(index="test-index") -async def test_multiple_indices_with_same_doc_type_work(write_client) -> None: - i1 = AsyncIndex("test-index-1", using=write_client) - i2 = AsyncIndex("test-index-2", using=write_client) +async def test_multiple_indices_with_same_doc_type_work(write_client: Any) -> None: + i1: Any = AsyncIndex("test-index-1", using=write_client) + i2: Any = AsyncIndex("test-index-2", using=write_client) for i in i1, i2: i.document(Post) diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_mapping.py b/test_opensearchpy/test_async/test_server/test_helpers/test_mapping.py index 6be391b3..35a4e8d8 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_mapping.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_mapping.py @@ -8,6 +8,8 @@ # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. +from typing import Any + import pytest from _pytest.mark.structures import MarkDecorator from pytest import raises @@ -19,7 +21,7 @@ pytestmark: MarkDecorator = pytest.mark.asyncio -async def test_mapping_saved_into_opensearch(write_client) -> None: +async def test_mapping_saved_into_opensearch(write_client: Any) -> None: m = mapping.AsyncMapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -40,7 +42,7 @@ async def test_mapping_saved_into_opensearch(write_client) -> None: async def test_mapping_saved_into_opensearch_when_index_already_exists_closed( - write_client, + write_client: Any, ) -> None: m = mapping.AsyncMapping() m.field( @@ -65,7 +67,7 @@ async def test_mapping_saved_into_opensearch_when_index_already_exists_closed( async def test_mapping_saved_into_opensearch_when_index_already_exists_with_analysis( - write_client, + write_client: Any, ) -> None: m = mapping.AsyncMapping() analyzer = analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -95,7 +97,7 @@ async def test_mapping_saved_into_opensearch_when_index_already_exists_with_anal } == await write_client.indices.get_mapping(index="test-mapping") -async def test_mapping_gets_updated_from_opensearch(write_client): +async def test_mapping_gets_updated_from_opensearch(write_client: Any) -> None: await write_client.indices.create( index="test-mapping", body={ diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_search.py b/test_opensearchpy/test_async/test_server/test_helpers/test_search.py index 2b995c54..8431fa4a 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_search.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_search.py @@ -10,6 +10,8 @@ from __future__ import unicode_literals +from typing import Any + import pytest from _pytest.mark.structures import MarkDecorator from pytest import raises @@ -29,7 +31,7 @@ class Repository(AsyncDocument): tags = Keyword() @classmethod - def search(cls): + def search(cls, using: Any = None, index: Any = None) -> Any: return super(Repository, cls).search().filter("term", commit_repo="repo") class Index: @@ -41,7 +43,7 @@ class Index: name = "flat-git" -async def test_filters_aggregation_buckets_are_accessible(data_client) -> None: +async def test_filters_aggregation_buckets_are_accessible(data_client: Any) -> None: has_tests_query = Q("term", files="test_opensearchpy/test_dsl") s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").bucket( @@ -62,7 +64,7 @@ async def test_filters_aggregation_buckets_are_accessible(data_client) -> None: ) -async def test_top_hits_are_wrapped_in_response(data_client) -> None: +async def test_top_hits_are_wrapped_in_response(data_client: Any) -> None: s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").metric( "top_commits", "top_hits", size=5 @@ -78,7 +80,7 @@ async def test_top_hits_are_wrapped_in_response(data_client) -> None: assert isinstance(hits[0], Commit) -async def test_inner_hits_are_wrapped_in_response(data_client) -> None: +async def test_inner_hits_are_wrapped_in_response(data_client: Any) -> None: s = AsyncSearch(index="git")[0:1].query( "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all") ) @@ -89,7 +91,7 @@ async def test_inner_hits_are_wrapped_in_response(data_client) -> None: assert repr(commit.meta.inner_hits.repo[0]).startswith(" None: +async def test_scan_respects_doc_types(data_client: Any) -> None: result = Repository.search().scan() repos = await get_result(result) @@ -98,7 +100,7 @@ async def test_scan_respects_doc_types(data_client) -> None: assert repos[0].organization == "opensearch" -async def test_scan_iterates_through_all_docs(data_client) -> None: +async def test_scan_iterates_through_all_docs(data_client: Any) -> None: s = AsyncSearch(index="flat-git") result = s.scan() commits = await get_result(result) @@ -107,14 +109,14 @@ async def test_scan_iterates_through_all_docs(data_client) -> None: assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} -async def get_result(b): +async def get_result(b: Any) -> Any: a = [] async for i in b: a.append(i) return a -async def test_multi_search(data_client) -> None: +async def test_multi_search(data_client: Any) -> None: s1 = Repository.search() s2 = AsyncSearch(index="flat-git") @@ -131,7 +133,7 @@ async def test_multi_search(data_client) -> None: assert r2._search is s2 -async def test_multi_missing(data_client) -> None: +async def test_multi_missing(data_client: Any) -> None: s1 = Repository.search() s2 = AsyncSearch(index="flat-git") s3 = AsyncSearch(index="does_not_exist") @@ -154,7 +156,7 @@ async def test_multi_missing(data_client) -> None: assert r3 is None -async def test_raw_subfield_can_be_used_in_aggs(data_client) -> None: +async def test_raw_subfield_can_be_used_in_aggs(data_client: Any) -> None: s = AsyncSearch(index="git")[0:0] s.aggs.bucket("authors", "terms", field="author.name.raw", size=1) r = await s.execute() diff --git a/test_opensearchpy/test_async/test_server/test_helpers/test_update_by_query.py b/test_opensearchpy/test_async/test_server/test_helpers/test_update_by_query.py index 4dcf32b3..46e515df 100644 --- a/test_opensearchpy/test_async/test_server/test_helpers/test_update_by_query.py +++ b/test_opensearchpy/test_async/test_server/test_helpers/test_update_by_query.py @@ -8,6 +8,8 @@ # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. +from typing import Any + import pytest from _pytest.mark.structures import MarkDecorator @@ -17,7 +19,9 @@ pytestmark: MarkDecorator = pytest.mark.asyncio -async def test_update_by_query_no_script(write_client, setup_ubq_tests) -> None: +async def test_update_by_query_no_script( + write_client: Any, setup_ubq_tests: Any +) -> None: index = setup_ubq_tests ubq = ( @@ -36,7 +40,9 @@ async def test_update_by_query_no_script(write_client, setup_ubq_tests) -> None: assert response.success() -async def test_update_by_query_with_script(write_client, setup_ubq_tests) -> None: +async def test_update_by_query_with_script( + write_client: Any, setup_ubq_tests: Any +) -> None: index = setup_ubq_tests ubq = ( @@ -53,7 +59,9 @@ async def test_update_by_query_with_script(write_client, setup_ubq_tests) -> Non assert response.version_conflicts == 0 -async def test_delete_by_query_with_script(write_client, setup_ubq_tests) -> None: +async def test_delete_by_query_with_script( + write_client: Any, setup_ubq_tests: Any +) -> None: index = setup_ubq_tests ubq = ( diff --git a/test_opensearchpy/test_async/test_server/test_plugins/test_alerting.py b/test_opensearchpy/test_async/test_server/test_plugins/test_alerting.py index 88b792db..09c0bc1e 100644 --- a/test_opensearchpy/test_async/test_server/test_plugins/test_alerting.py +++ b/test_opensearchpy/test_async/test_server/test_plugins/test_alerting.py @@ -16,7 +16,7 @@ import pytest from _pytest.mark.structures import MarkDecorator -from opensearchpy.helpers.test import OPENSEARCH_VERSION +from opensearchpy.helpers.test import OPENSEARCH_VERSION # type: ignore from .. import AsyncOpenSearchTestCase @@ -28,7 +28,7 @@ class TestAlertingPlugin(AsyncOpenSearchTestCase): (OPENSEARCH_VERSION) and (OPENSEARCH_VERSION < (2, 0, 0)), "Plugin not supported for opensearch version", ) - async def test_create_destination(self): + async def test_create_destination(self) -> None: # Test to create alert destination dummy_destination = { "name": "my-destination", @@ -59,7 +59,7 @@ async def test_get_destination(self) -> None: (OPENSEARCH_VERSION) and (OPENSEARCH_VERSION < (2, 0, 0)), "Plugin not supported for opensearch version", ) - async def test_create_monitor(self): + async def test_create_monitor(self) -> None: # Create a dummy destination await self.test_create_destination() diff --git a/test_opensearchpy/test_async/test_server/test_rest_api_spec.py b/test_opensearchpy/test_async/test_server/test_rest_api_spec.py index 0efcd25e..cf13a80f 100644 --- a/test_opensearchpy/test_async/test_server/test_rest_api_spec.py +++ b/test_opensearchpy/test_async/test_server/test_rest_api_spec.py @@ -33,12 +33,13 @@ """ import inspect import warnings +from typing import Any import pytest from _pytest.mark.structures import MarkDecorator from opensearchpy import OpenSearchWarning -from opensearchpy.helpers.test import _get_version +from opensearchpy.helpers.test import _get_version # type: ignore from ...test_server.test_rest_api_spec import ( IMPLEMENTED_FEATURES, @@ -53,14 +54,14 @@ OPENSEARCH_VERSION = None -async def await_if_coro(x): +async def await_if_coro(x: Any) -> Any: if inspect.iscoroutine(x): return await x return x class AsyncYamlRunner(YamlRunner): - async def setup(self): + async def setup(self) -> None: # Pull skips from individual tests to not do unnecessary setup. skip_code = [] for action in self._run_code: @@ -78,12 +79,12 @@ async def setup(self): if self._setup_code: await self.run_code(self._setup_code) - async def teardown(self) -> None: + async def teardown(self) -> Any: if self._teardown_code: self.section("teardown") await self.run_code(self._teardown_code) - async def opensearch_version(self): + async def opensearch_version(self) -> Any: global OPENSEARCH_VERSION if OPENSEARCH_VERSION is None: version_string = (await self.client.info())["version"]["number"] @@ -93,10 +94,10 @@ async def opensearch_version(self): OPENSEARCH_VERSION = tuple(int(v) if v.isdigit() else 999 for v in version) return OPENSEARCH_VERSION - def section(self, name) -> None: + def section(self, name: str) -> None: print(("=" * 10) + " " + name + " " + ("=" * 10)) - async def run(self) -> None: + async def run(self) -> Any: try: await self.setup() self.section("test") @@ -107,7 +108,7 @@ async def run(self) -> None: except Exception: pass - async def run_code(self, test) -> None: + async def run_code(self, test: Any) -> Any: """Execute an instruction based on its type.""" for action in test: assert len(action) == 1 @@ -119,7 +120,7 @@ async def run_code(self, test) -> None: else: raise RuntimeError("Invalid action type %r" % (action_type,)) - async def run_do(self, action) -> None: + async def run_do(self, action: Any) -> Any: api = self.client headers = action.pop("headers", None) catch = action.pop("catch", None) @@ -171,7 +172,7 @@ async def run_do(self, action) -> None: # Filter out warnings raised by other components. caught_warnings = [ - str(w.message) + str(w.message) # type: ignore for w in caught_warnings if w.category == OpenSearchWarning and str(w.message) not in allowed_warnings @@ -179,13 +180,13 @@ async def run_do(self, action) -> None: # Sorting removes the issue with order raised. We only care about # if all warnings are raised in the single API call. - if warn and sorted(warn) != sorted(caught_warnings): + if warn and sorted(warn) != sorted(caught_warnings): # type: ignore raise AssertionError( "Expected warnings not equal to actual warnings: expected=%r actual=%r" % (warn, caught_warnings) ) - async def run_skip(self, skip) -> None: + async def run_skip(self, skip: Any) -> Any: if "features" in skip: features = skip["features"] if not isinstance(features, (tuple, list)): @@ -205,19 +206,19 @@ async def run_skip(self, skip) -> None: if min_version <= (await self.opensearch_version()) <= max_version: pytest.skip(reason) - async def _feature_enabled(self, name) -> bool: + async def _feature_enabled(self, name: str) -> Any: return False -@pytest.fixture(scope="function") -def async_runner(async_client): +@pytest.fixture(scope="function") # type: ignore +def async_runner(async_client: Any) -> AsyncYamlRunner: return AsyncYamlRunner(async_client) if RUN_ASYNC_REST_API_TESTS: - @pytest.mark.parametrize("test_spec", YAML_TEST_SPECS) - async def test_rest_api_spec(test_spec, async_runner) -> None: + @pytest.mark.parametrize("test_spec", YAML_TEST_SPECS) # type: ignore + async def test_rest_api_spec(test_spec: Any, async_runner: Any) -> None: if test_spec.get("skip", False): pytest.skip("Manually skipped in 'SKIP_TESTS'") async_runner.use_spec(test_spec) diff --git a/test_opensearchpy/test_async/test_server_secured/test_security_plugin.py b/test_opensearchpy/test_async/test_server_secured/test_security_plugin.py index 9fe8d9d1..f8726152 100644 --- a/test_opensearchpy/test_async/test_server_secured/test_security_plugin.py +++ b/test_opensearchpy/test_async/test_server_secured/test_security_plugin.py @@ -123,7 +123,7 @@ async def test_create_user_with_body_param_empty(self) -> None: else: assert False - async def test_create_user_with_role(self): + async def test_create_user_with_role(self) -> None: await self.test_create_role() # Test to create user diff --git a/test_opensearchpy/test_async/test_signer.py b/test_opensearchpy/test_async/test_signer.py index 50d734bc..319340da 100644 --- a/test_opensearchpy/test_async/test_signer.py +++ b/test_opensearchpy/test_async/test_signer.py @@ -18,7 +18,7 @@ class TestAsyncSigner: - def mock_session(self): + def mock_session(self) -> Mock: access_key = uuid.uuid4().hex secret_key = uuid.uuid4().hex token = uuid.uuid4().hex @@ -37,7 +37,7 @@ async def test_aws_signer_async_as_http_auth(self) -> None: from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth auth = AWSV4SignerAsyncAuth(self.mock_session(), region) - headers = auth("GET", "http://localhost", {}, {}) + headers = auth("GET", "http://localhost") assert "Authorization" in headers assert "X-Amz-Date" in headers assert "X-Amz-Security-Token" in headers @@ -48,7 +48,7 @@ async def test_aws_signer_async_when_region_is_null(self) -> None: from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth with pytest.raises(ValueError) as e: - AWSV4SignerAsyncAuth(session, None) + AWSV4SignerAsyncAuth(session, None) # type: ignore assert str(e.value) == "Region cannot be empty" with pytest.raises(ValueError) as e: @@ -71,7 +71,7 @@ async def test_aws_signer_async_when_service_is_specified(self) -> None: from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service) - headers = auth("GET", "http://localhost", {}, {}) + headers = auth("GET", "http://localhost") assert "Authorization" in headers assert "X-Amz-Date" in headers assert "X-Amz-Security-Token" in headers @@ -79,7 +79,7 @@ async def test_aws_signer_async_when_service_is_specified(self) -> None: class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner): - def mock_session(self, disable_get_frozen: bool = True): + def mock_session(self, disable_get_frozen: bool = True) -> Mock: access_key = uuid.uuid4().hex secret_key = uuid.uuid4().hex token = uuid.uuid4().hex @@ -99,7 +99,7 @@ async def test_aws_signer_async_frozen_credentials_as_http_auth(self) -> None: mock_session = self.mock_session() auth = AWSV4SignerAsyncAuth(mock_session, region) - headers = auth("GET", "http://localhost", {}, {}) + headers = auth("GET", "http://localhost") assert "Authorization" in headers assert "X-Amz-Date" in headers assert "X-Amz-Security-Token" in headers diff --git a/test_opensearchpy/test_async/test_transport.py b/test_opensearchpy/test_async/test_transport.py index 4dabee05..4ef80707 100644 --- a/test_opensearchpy/test_async/test_transport.py +++ b/test_opensearchpy/test_async/test_transport.py @@ -45,16 +45,16 @@ class DummyConnection(Connection): - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: self.exception = kwargs.pop("exception", None) self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}") self.headers = kwargs.pop("headers", {}) self.delay = kwargs.pop("delay", 0) - self.calls = [] + self.calls: Any = [] self.closed = False super(DummyConnection, self).__init__(**kwargs) - async def perform_request(self, *args, **kwargs) -> Any: + async def perform_request(self, *args: Any, **kwargs: Any) -> Any: if self.closed: raise RuntimeError("This connection is closed") if self.delay: @@ -123,15 +123,15 @@ async def close(self) -> None: class TestTransport: async def test_single_connection_uses_dummy_connection_pool(self) -> None: - t = AsyncTransport([{}]) - await t._async_call() - assert isinstance(t.connection_pool, DummyConnectionPool) - t = AsyncTransport([{"host": "localhost"}]) - await t._async_call() - assert isinstance(t.connection_pool, DummyConnectionPool) + t1: Any = AsyncTransport([{}]) + await t1._async_call() + assert isinstance(t1.connection_pool, DummyConnectionPool) + t2: Any = AsyncTransport([{"host": "localhost"}]) + await t2._async_call() + assert isinstance(t2.connection_pool, DummyConnectionPool) async def test_request_timeout_extracted_from_params_and_passed(self) -> None: - t = AsyncTransport([{}], connection_class=DummyConnection) + t: Any = AsyncTransport([{}], connection_class=DummyConnection) await t.perform_request("GET", "/", params={"request_timeout": 42}) assert 1 == len(t.get_connection().calls) @@ -143,7 +143,7 @@ async def test_request_timeout_extracted_from_params_and_passed(self) -> None: } == t.get_connection().calls[0][1] async def test_timeout_extracted_from_params_and_passed(self) -> None: - t = AsyncTransport([{}], connection_class=DummyConnection) + t: Any = AsyncTransport([{}], connection_class=DummyConnection) await t.perform_request("GET", "/", params={"timeout": 84}) assert 1 == len(t.get_connection().calls) @@ -154,8 +154,10 @@ async def test_timeout_extracted_from_params_and_passed(self) -> None: "headers": None, } == t.get_connection().calls[0][1] - async def test_opaque_id(self): - t = AsyncTransport([{}], opaque_id="app-1", connection_class=DummyConnection) + async def test_opaque_id(self) -> None: + t: Any = AsyncTransport( + [{}], opaque_id="app-1", connection_class=DummyConnection + ) await t.perform_request("GET", "/") assert 1 == len(t.get_connection().calls) @@ -176,8 +178,8 @@ async def test_opaque_id(self): "headers": {"x-opaque-id": "request-1"}, } == t.get_connection().calls[1][1] - async def test_request_with_custom_user_agent_header(self): - t = AsyncTransport([{}], connection_class=DummyConnection) + async def test_request_with_custom_user_agent_header(self) -> None: + t: Any = AsyncTransport([{}], connection_class=DummyConnection) await t.perform_request( "GET", "/", headers={"user-agent": "my-custom-value/1.2.3"} @@ -190,7 +192,7 @@ async def test_request_with_custom_user_agent_header(self): } == t.get_connection().calls[0][1] async def test_send_get_body_as_source(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{}], send_get_body_as="source", connection_class=DummyConnection ) @@ -199,7 +201,7 @@ async def test_send_get_body_as_source(self) -> None: assert ("GET", "/", {"source": "{}"}, None) == t.get_connection().calls[0][0] async def test_send_get_body_as_post(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{}], send_get_body_as="POST", connection_class=DummyConnection ) @@ -208,7 +210,7 @@ async def test_send_get_body_as_post(self) -> None: assert ("POST", "/", None, b"{}") == t.get_connection().calls[0][0] async def test_body_gets_encoded_into_bytes(self) -> None: - t = AsyncTransport([{}], connection_class=DummyConnection) + t: Any = AsyncTransport([{}], connection_class=DummyConnection) await t.perform_request("GET", "/", body="你好") assert 1 == len(t.get_connection().calls) @@ -220,7 +222,7 @@ async def test_body_gets_encoded_into_bytes(self) -> None: ) == t.get_connection().calls[0][0] async def test_body_bytes_get_passed_untouched(self) -> None: - t = AsyncTransport([{}], connection_class=DummyConnection) + t: Any = AsyncTransport([{}], connection_class=DummyConnection) body = b"\xe4\xbd\xa0\xe5\xa5\xbd" await t.perform_request("GET", "/", body=body) @@ -228,7 +230,7 @@ async def test_body_bytes_get_passed_untouched(self) -> None: assert ("GET", "/", None, body) == t.get_connection().calls[0][0] async def test_body_surrogates_replaced_encoded_into_bytes(self) -> None: - t = AsyncTransport([{}], connection_class=DummyConnection) + t: Any = AsyncTransport([{}], connection_class=DummyConnection) await t.perform_request("GET", "/", body="你好\uda6a") assert 1 == len(t.get_connection().calls) @@ -240,36 +242,36 @@ async def test_body_surrogates_replaced_encoded_into_bytes(self) -> None: ) == t.get_connection().calls[0][0] async def test_kwargs_passed_on_to_connections(self) -> None: - t = AsyncTransport([{"host": "google.com"}], port=123) + t: Any = AsyncTransport([{"host": "google.com"}], port=123) await t._async_call() assert 1 == len(t.connection_pool.connections) assert "http://google.com:123" == t.connection_pool.connections[0].host async def test_kwargs_passed_on_to_connection_pool(self) -> None: dt = object() - t = AsyncTransport([{}, {}], dead_timeout=dt) + t: Any = AsyncTransport([{}, {}], dead_timeout=dt) await t._async_call() assert dt is t.connection_pool.dead_timeout async def test_custom_connection_class(self) -> None: class MyConnection(object): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: self.kwargs = kwargs - t = AsyncTransport([{}], connection_class=MyConnection) + t: Any = AsyncTransport([{}], connection_class=MyConnection) await t._async_call() assert 1 == len(t.connection_pool.connections) assert isinstance(t.connection_pool.connections[0], MyConnection) async def test_add_connection(self) -> None: - t = AsyncTransport([{}], randomize_hosts=False) + t: Any = AsyncTransport([{}], randomize_hosts=False) t.add_connection({"host": "google.com", "port": 1234}) assert 2 == len(t.connection_pool.connections) assert "http://google.com:1234" == t.connection_pool.connections[1].host async def test_request_will_fail_after_X_retries(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{"exception": ConnectionError("abandon ship")}], connection_class=DummyConnection, ) @@ -284,7 +286,7 @@ async def test_request_will_fail_after_X_retries(self) -> None: assert 4 == len(t.get_connection().calls) async def test_failed_connection_will_be_marked_as_dead(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{"exception": ConnectionError("abandon ship")}] * 2, connection_class=DummyConnection, ) @@ -302,7 +304,7 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success( self, ) -> None: for method in ("GET", "HEAD"): - t = AsyncTransport([{}, {}], connection_class=DummyConnection) + t: Any = AsyncTransport([{}, {}], connection_class=DummyConnection) await t._async_call() con1 = t.connection_pool.get_connection() con2 = t.connection_pool.get_connection() @@ -314,7 +316,9 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success( assert 1 == len(t.connection_pool.dead_count) async def test_sniff_will_use_seed_connections(self) -> None: - t = AsyncTransport([{"data": CLUSTER_NODES}], connection_class=DummyConnection) + t: Any = AsyncTransport( + [{"data": CLUSTER_NODES}], connection_class=DummyConnection + ) await t._async_call() t.set_connections([{"data": "invalid"}]) @@ -323,7 +327,7 @@ async def test_sniff_will_use_seed_connections(self) -> None: assert "http://1.1.1.1:123" == t.get_connection().host async def test_sniff_on_start_fetches_and_uses_nodes_list(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True, @@ -335,7 +339,7 @@ async def test_sniff_on_start_fetches_and_uses_nodes_list(self) -> None: assert "http://1.1.1.1:123" == t.get_connection().host async def test_sniff_on_start_ignores_sniff_timeout(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True, @@ -349,7 +353,7 @@ async def test_sniff_on_start_ignores_sniff_timeout(self) -> None: ].calls[0] async def test_sniff_uses_sniff_timeout(self) -> None: - t = AsyncTransport( + t: Any = AsyncTransport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_timeout=42, @@ -361,8 +365,8 @@ async def test_sniff_uses_sniff_timeout(self) -> None: 0 ].calls[0] - async def test_sniff_reuses_connection_instances_if_possible(self): - t = AsyncTransport( + async def test_sniff_reuses_connection_instances_if_possible(self) -> None: + t: Any = AsyncTransport( [{"data": CLUSTER_NODES}, {"host": "1.1.1.1", "port": 123}], connection_class=DummyConnection, randomize_hosts=False, @@ -375,8 +379,8 @@ async def test_sniff_reuses_connection_instances_if_possible(self): assert 1 == len(t.connection_pool.connections) assert connection is t.get_connection() - async def test_sniff_on_fail_triggers_sniffing_on_fail(self): - t = AsyncTransport( + async def test_sniff_on_fail_triggers_sniffing_on_fail(self) -> None: + t: Any = AsyncTransport( [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_connection_fail=True, @@ -398,9 +402,11 @@ async def test_sniff_on_fail_triggers_sniffing_on_fail(self): assert "http://1.1.1.1:123" == t.get_connection().host @patch("opensearchpy._async.transport.AsyncTransport.sniff_hosts") - async def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): + async def test_sniff_on_fail_failing_does_not_prevent_retires( + self, sniff_hosts: Any + ) -> None: sniff_hosts.side_effect = [TransportError("sniff failed")] - t = AsyncTransport( + t: Any = AsyncTransport( [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_connection_fail=True, @@ -416,8 +422,8 @@ async def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts) assert 1 == len(conn_err.calls) assert 1 == len(conn_data.calls) - async def test_sniff_after_n_seconds(self, event_loop) -> None: - t = AsyncTransport( + async def test_sniff_after_n_seconds(self, event_loop: Any) -> None: + t: Any = AsyncTransport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniffer_timeout=5, @@ -440,7 +446,7 @@ async def test_sniff_after_n_seconds(self, event_loop) -> None: async def test_sniff_7x_publish_host(self) -> None: # Test the response shaped when a 7.x node has publish_host set # and the returend data is shaped in the fqdn/ip:port format. - t = AsyncTransport( + t: Any = AsyncTransport( [{"data": CLUSTER_NODES_7x_PUBLISH_HOST}], connection_class=DummyConnection, sniff_timeout=42, @@ -454,22 +460,24 @@ async def test_sniff_7x_publish_host(self) -> None: } async def test_transport_close_closes_all_pool_connections(self) -> None: - t = AsyncTransport([{}], connection_class=DummyConnection) - await t._async_call() + t1: Any = AsyncTransport([{}], connection_class=DummyConnection) + await t1._async_call() - assert not any([conn.closed for conn in t.connection_pool.connections]) - await t.close() - assert all([conn.closed for conn in t.connection_pool.connections]) + assert not any([conn.closed for conn in t1.connection_pool.connections]) + await t1.close() + assert all([conn.closed for conn in t1.connection_pool.connections]) - t = AsyncTransport([{}, {}], connection_class=DummyConnection) - await t._async_call() + t2: Any = AsyncTransport([{}, {}], connection_class=DummyConnection) + await t2._async_call() - assert not any([conn.closed for conn in t.connection_pool.connections]) - await t.close() - assert all([conn.closed for conn in t.connection_pool.connections]) + assert not any([conn.closed for conn in t2.connection_pool.connections]) + await t2.close() + assert all([conn.closed for conn in t2.connection_pool.connections]) - async def test_sniff_on_start_error_if_no_sniffed_hosts(self, event_loop) -> None: - t = AsyncTransport( + async def test_sniff_on_start_error_if_no_sniffed_hosts( + self, event_loop: Any + ) -> None: + t: Any = AsyncTransport( [ {"data": ""}, {"data": ""}, @@ -485,8 +493,10 @@ async def test_sniff_on_start_error_if_no_sniffed_hosts(self, event_loop) -> Non await t._async_call() assert str(e.value) == "TransportError(N/A, 'Unable to sniff hosts.')" - async def test_sniff_on_start_waits_for_sniff_to_complete(self, event_loop): - t = AsyncTransport( + async def test_sniff_on_start_waits_for_sniff_to_complete( + self, event_loop: Any + ) -> None: + t: Any = AsyncTransport( [ {"delay": 1, "data": ""}, {"delay": 1, "data": ""}, @@ -521,8 +531,10 @@ async def test_sniff_on_start_waits_for_sniff_to_complete(self, event_loop): # and then resolved immediately after. assert 1 <= duration < 2 - async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop): - t = AsyncTransport( + async def test_sniff_on_start_close_unlocks_async_calls( + self, event_loop: Any + ) -> None: + t: Any = AsyncTransport( [ {"delay": 10, "data": CLUSTER_NODES}, ], @@ -559,7 +571,7 @@ async def test_init_connection_pool_with_many_hosts(self) -> None: """ amt_hosts = 4 hosts = [{"host": "localhost", "port": 9092}] * amt_hosts - t = AsyncTransport( + t: Any = AsyncTransport( hosts=hosts, ) await t._async_init() @@ -577,7 +589,7 @@ async def test_init_pool_with_connection_class_to_many_hosts(self) -> None: """ amt_hosts = 4 hosts = [{"host": "localhost", "port": 9092}] * amt_hosts - t = AsyncTransport( + t: Any = AsyncTransport( hosts=hosts, connection_class=AIOHttpConnection, ) diff --git a/test_opensearchpy/test_cases.py b/test_opensearchpy/test_cases.py index ad795bcf..e36d9bb6 100644 --- a/test_opensearchpy/test_cases.py +++ b/test_opensearchpy/test_cases.py @@ -27,21 +27,30 @@ from collections import defaultdict -from unittest import SkipTest # noqa: F401 -from unittest import TestCase +from typing import Any, Sequence +from unittest import SkipTest, TestCase from opensearchpy import OpenSearch class DummyTransport(object): - def __init__(self, hosts, responses=None, **kwargs) -> None: + def __init__( + self, hosts: Sequence[str], responses: Any = None, **kwargs: Any + ) -> None: self.hosts = hosts self.responses = responses - self.call_count = 0 - self.calls = defaultdict(list) + self.call_count: int = 0 + self.calls: Any = defaultdict(list) - def perform_request(self, method, url, params=None, headers=None, body=None): - resp = 200, {} + def perform_request( + self, + method: str, + url: str, + params: Any = None, + headers: Any = None, + body: Any = None, + ) -> Any: + resp: Any = (200, {}) if self.responses: resp = self.responses[self.call_count] self.call_count += 1 @@ -52,12 +61,12 @@ def perform_request(self, method, url, params=None, headers=None, body=None): class OpenSearchTestCase(TestCase): def setUp(self) -> None: super(OpenSearchTestCase, self).setUp() - self.client = OpenSearch(transport_class=DummyTransport) + self.client: Any = OpenSearch(transport_class=DummyTransport) # type: ignore - def assert_call_count_equals(self, count) -> None: + def assert_call_count_equals(self, count: int) -> None: self.assertEqual(count, self.client.transport.call_count) - def assert_url_called(self, method, url, count: int = 1): + def assert_url_called(self, method: str, url: str, count: int = 1) -> Any: self.assertIn((method, url), self.client.transport.calls) calls = self.client.transport.calls[(method, url)] self.assertEqual(count, len(calls)) @@ -78,3 +87,6 @@ def test_each_call_is_recorded(self) -> None: self.assertEqual( [({}, None, "body")], self.assert_url_called("DELETE", "/42", 1) ) + + +__all__ = ["SkipTest", "TestCase"] diff --git a/test_opensearchpy/test_client/test_plugins/test_plugins_client.py b/test_opensearchpy/test_client/test_plugins/test_plugins_client.py index d09731bf..ed65dca4 100644 --- a/test_opensearchpy/test_client/test_plugins/test_plugins_client.py +++ b/test_opensearchpy/test_client/test_plugins/test_plugins_client.py @@ -17,7 +17,8 @@ class TestPluginsClient(TestCase): def test_plugins_client(self) -> None: with self.assertWarns(Warning) as w: client = OpenSearch() - client.plugins.__init__(client) # double-init + # double-init + client.plugins.__init__(client) # type: ignore self.assertEqual( str(w.warnings[0].message), "Cannot load `alerting` directly to OpenSearch as it already exists. Use `OpenSearch.plugin.alerting` instead.", diff --git a/test_opensearchpy/test_client/test_utils.py b/test_opensearchpy/test_client/test_utils.py index b6a034eb..797624fc 100644 --- a/test_opensearchpy/test_client/test_utils.py +++ b/test_opensearchpy/test_client/test_utils.py @@ -28,17 +28,19 @@ from __future__ import unicode_literals +from typing import Any + from opensearchpy.client.utils import _bulk_body, _escape, _make_path, query_params from ..test_cases import TestCase class TestQueryParams(TestCase): - def setup_method(self, _) -> None: - self.calls = [] + def setup_method(self, _: Any) -> None: + self.calls: Any = [] @query_params("simple_param") - def func_to_wrap(self, *args, **kwargs) -> None: + def func_to_wrap(self, *args: Any, **kwargs: Any) -> None: self.calls.append((args, kwargs)) def test_handles_params(self) -> None: diff --git a/test_opensearchpy/test_connection/test_base_connection.py b/test_opensearchpy/test_connection/test_base_connection.py index 6ba12d0d..45cc46fd 100644 --- a/test_opensearchpy/test_connection/test_base_connection.py +++ b/test_opensearchpy/test_connection/test_base_connection.py @@ -88,7 +88,7 @@ def test_raises_warnings_when_folded(self) -> None: self.assertEqual([str(w.message) for w in warn], ["warning", "folded"]) - def test_ipv6_host_and_port(self): + def test_ipv6_host_and_port(self) -> None: for kwargs, expected_host in [ ({"host": "::1"}, "http://[::1]:9200"), ({"host": "::1", "port": 443}, "http://[::1]:443"), @@ -96,7 +96,7 @@ def test_ipv6_host_and_port(self): ({"host": "127.0.0.1", "port": 1234}, "http://127.0.0.1:1234"), ({"host": "localhost", "use_ssl": True}, "https://localhost:9200"), ]: - conn = Connection(**kwargs) + conn = Connection(**kwargs) # type: ignore assert conn.host == expected_host def test_compatibility_accept_header(self) -> None: diff --git a/test_opensearchpy/test_connection/test_requests_http_connection.py b/test_opensearchpy/test_connection/test_requests_http_connection.py index 409981f0..7043ec54 100644 --- a/test_opensearchpy/test_connection/test_requests_http_connection.py +++ b/test_opensearchpy/test_connection/test_requests_http_connection.py @@ -30,6 +30,7 @@ import re import uuid import warnings +from typing import Any import pytest from mock import Mock, patch @@ -49,24 +50,27 @@ class TestRequestsHttpConnection(TestCase): def _get_mock_connection( - self, connection_params={}, status_code: int = 200, response_body: bytes = b"{}" - ): + self, + connection_params: Any = {}, + status_code: int = 200, + response_body: bytes = b"{}", + ) -> Any: con = RequestsHttpConnection(**connection_params) - def _dummy_send(*args, **kwargs): + def _dummy_send(*args: Any, **kwargs: Any) -> Any: dummy_response = Mock() dummy_response.headers = {} dummy_response.status_code = status_code dummy_response.content = response_body dummy_response.request = args[0] dummy_response.cookies = {} - _dummy_send.call_args = (args, kwargs) + _dummy_send.call_args = (args, kwargs) # type: ignore return dummy_response - con.session.send = _dummy_send + con.session.send = _dummy_send # type: ignore return con - def _get_request(self, connection, *args, **kwargs): + def _get_request(self, connection: Any, *args: Any, **kwargs: Any) -> Any: if "body" in kwargs: kwargs["body"] = kwargs["body"].encode("utf-8") @@ -237,14 +241,14 @@ def test_request_error_is_returned_on_400(self) -> None: self.assertRaises(RequestError, con.perform_request, "GET", "/", {}, "") @patch("opensearchpy.connection.base.logger") - def test_head_with_404_doesnt_get_logged(self, logger) -> None: + def test_head_with_404_doesnt_get_logged(self, logger: Any) -> None: con = self._get_mock_connection(status_code=404) self.assertRaises(NotFoundError, con.perform_request, "HEAD", "/", {}, "") self.assertEqual(0, logger.warning.call_count) @patch("opensearchpy.connection.base.tracer") @patch("opensearchpy.connection.base.logger") - def test_failed_request_logs_and_traces(self, logger, tracer) -> None: + def test_failed_request_logs_and_traces(self, logger: Any, tracer: Any) -> None: con = self._get_mock_connection( response_body=b'{"answer": 42}', status_code=500 ) @@ -272,7 +276,7 @@ def test_failed_request_logs_and_traces(self, logger, tracer) -> None: @patch("opensearchpy.connection.base.tracer") @patch("opensearchpy.connection.base.logger") - def test_success_logs_and_traces(self, logger, tracer) -> None: + def test_success_logs_and_traces(self, logger: Any, tracer: Any) -> None: con = self._get_mock_connection(response_body=b"""{"answer": "that's it!"}""") status, headers, data = con.perform_request( "GET", @@ -311,7 +315,7 @@ def test_success_logs_and_traces(self, logger, tracer) -> None: self.assertEqual('< {"answer": "that\'s it!"}', resp[0][0] % resp[0][1:]) @patch("opensearchpy.connection.base.logger") - def test_uncompressed_body_logged(self, logger) -> None: + def test_uncompressed_body_logged(self, logger: Any) -> None: con = self._get_mock_connection(connection_params={"http_compress": True}) con.perform_request("GET", "/", body=b'{"example": "body"}') @@ -366,7 +370,7 @@ def test_http_auth_attached(self) -> None: self.assertEqual(request.headers["authorization"], "Basic dXNlcm5hbWU6c2VjcmV0") @patch("opensearchpy.connection.base.tracer") - def test_url_prefix(self, tracer) -> None: + def test_url_prefix(self, tracer: Any) -> None: con = self._get_mock_connection({"url_prefix": "/some-prefix/"}) request = self._get_request( con, "GET", "/_search", body='{"answer": 42}', timeout=0.1 @@ -392,16 +396,16 @@ def test_surrogatepass_into_bytes(self) -> None: def test_recursion_error_reraised(self) -> None: conn = RequestsHttpConnection() - def send_raise(*_, **__): + def send_raise(*_: Any, **__: Any) -> Any: raise RecursionError("Wasn't modified!") - conn.session.send = send_raise + conn.session.send = send_raise # type: ignore with pytest.raises(RecursionError) as e: conn.perform_request("GET", "/") assert str(e.value) == "Wasn't modified!" - def mock_session(self): + def mock_session(self) -> Any: access_key = uuid.uuid4().hex secret_key = uuid.uuid4().hex token = uuid.uuid4().hex @@ -448,7 +452,7 @@ def test_aws_signer_when_service_is_specified(self) -> None: self.assertIn("X-Amz-Security-Token", prepared_request.headers) @patch("opensearchpy.helpers.signer.AWSV4Signer.sign") - def test_aws_signer_signs_with_query_string(self, mock_sign) -> None: + def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None: region = "us-west-1" service = "aoss" @@ -469,6 +473,9 @@ def test_aws_signer_signs_with_query_string(self, mock_sign) -> None: class TestRequestsConnectionRedirect: + server1: TestHTTPServer + server2: TestHTTPServer + @classmethod def setup_class(cls) -> None: # Start servers @@ -505,7 +512,7 @@ def test_redirect_success_when_allow_redirect_true(self) -> None: class TestSignerWithFrozenCredentials(TestRequestsHttpConnection): - def mock_session(self): + def mock_session(self) -> Any: access_key = uuid.uuid4().hex secret_key = uuid.uuid4().hex token = uuid.uuid4().hex diff --git a/test_opensearchpy/test_connection/test_urllib3_http_connection.py b/test_opensearchpy/test_connection/test_urllib3_http_connection.py index c87d8ac0..9720283b 100644 --- a/test_opensearchpy/test_connection/test_urllib3_http_connection.py +++ b/test_opensearchpy/test_connection/test_urllib3_http_connection.py @@ -32,6 +32,7 @@ from gzip import GzipFile from io import BytesIO from platform import python_version +from typing import Any import pytest import urllib3 @@ -45,15 +46,17 @@ class TestUrllib3HttpConnection(TestCase): - def _get_mock_connection(self, connection_params={}, response_body: bytes = b"{}"): + def _get_mock_connection( + self, connection_params: Any = {}, response_body: bytes = b"{}" + ) -> Any: con = Urllib3HttpConnection(**connection_params) - def _dummy_urlopen(*args, **kwargs): + def _dummy_urlopen(*args: Any, **kwargs: Any) -> Any: dummy_response = Mock() dummy_response.headers = HTTPHeaderDict({}) dummy_response.status = 200 dummy_response.data = response_body - _dummy_urlopen.call_args = (args, kwargs) + _dummy_urlopen.call_args = (args, kwargs) # type: ignore return dummy_response con.pool.urlopen = _dummy_urlopen @@ -181,7 +184,7 @@ def test_http_auth_list(self) -> None: "urllib3.HTTPConnectionPool.urlopen", return_value=Mock(status=200, headers=HTTPHeaderDict({}), data=b"{}"), ) - def test_aws_signer_as_http_auth_adds_headers(self, mock_open) -> None: + def test_aws_signer_as_http_auth_adds_headers(self, mock_open: Any) -> None: from opensearchpy.helpers.signer import Urllib3AWSV4SignerAuth auth = Urllib3AWSV4SignerAuth(self.mock_session(), "us-west-2") @@ -247,7 +250,7 @@ def test_aws_signer_when_service_is_specified(self) -> None: self.assertIn("X-Amz-Date", headers) self.assertIn("X-Amz-Security-Token", headers) - def mock_session(self): + def mock_session(self) -> Any: access_key = uuid.uuid4().hex secret_key = uuid.uuid4().hex token = uuid.uuid4().hex @@ -290,6 +293,7 @@ def test_no_warning_when_using_ssl_context(self) -> None: self.assertEqual(0, len(w)) def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self) -> None: + kwargs: Any for kwargs in ( {"ssl_show_warn": False}, {"ssl_show_warn": True}, @@ -325,7 +329,7 @@ def test_uses_no_ca_certs(self) -> None: self.assertIsNone(c.pool.ca_certs) @patch("opensearchpy.connection.base.logger") - def test_uncompressed_body_logged(self, logger) -> None: + def test_uncompressed_body_logged(self, logger: Any) -> None: con = self._get_mock_connection(connection_params={"http_compress": True}) con.perform_request("GET", "/", body=b'{"example": "body"}') @@ -344,7 +348,7 @@ def test_surrogatepass_into_bytes(self) -> None: def test_recursion_error_reraised(self) -> None: conn = Urllib3HttpConnection() - def urlopen_raise(*_, **__): + def urlopen_raise(*_: Any, **__: Any) -> Any: raise RecursionError("Wasn't modified!") conn.pool.urlopen = urlopen_raise @@ -355,7 +359,7 @@ def urlopen_raise(*_, **__): class TestSignerWithFrozenCredentials(TestUrllib3HttpConnection): - def mock_session(self): + def mock_session(self) -> Any: access_key = uuid.uuid4().hex secret_key = uuid.uuid4().hex token = uuid.uuid4().hex diff --git a/test_opensearchpy/test_connection_pool.py b/test_opensearchpy/test_connection_pool.py index 5630030e..45afd93e 100644 --- a/test_opensearchpy/test_connection_pool.py +++ b/test_opensearchpy/test_connection_pool.py @@ -27,6 +27,7 @@ import time +from typing import Any from opensearchpy.connection import Connection from opensearchpy.connection_pool import ( @@ -57,7 +58,7 @@ def test_default_round_robin(self) -> None: connections.add(pool.get_connection()) self.assertEqual(connections, set(range(100))) - def test_disable_shuffling(self): + def test_disable_shuffling(self) -> None: pool = ConnectionPool([(x, {}) for x in range(100)], randomize_hosts=False) connections = [] @@ -65,9 +66,9 @@ def test_disable_shuffling(self): connections.append(pool.get_connection()) self.assertEqual(connections, list(range(100))) - def test_selectors_have_access_to_connection_opts(self): + def test_selectors_have_access_to_connection_opts(self) -> None: class MySelector(RoundRobinSelector): - def select(self, connections): + def select(self, connections: Any) -> Any: return self.connection_opts[ super(MySelector, self).select(connections) ]["actual"] diff --git a/test_opensearchpy/test_helpers/conftest.py b/test_opensearchpy/test_helpers/conftest.py index 9c93ccd0..09778000 100644 --- a/test_opensearchpy/test_helpers/conftest.py +++ b/test_opensearchpy/test_helpers/conftest.py @@ -26,24 +26,26 @@ # under the License. +from typing import Any + from mock import Mock from pytest import fixture from opensearchpy.connection.connections import add_connection, connections -@fixture -def mock_client(dummy_response): +@fixture # type: ignore +def mock_client(dummy_response: Any) -> Any: client = Mock() client.search.return_value = dummy_response add_connection("mock", client) yield client - connections._conn = {} + connections._conns = {} connections._kwargs = {} -@fixture -def dummy_response(): +@fixture # type: ignore +def dummy_response() -> Any: return { "_shards": {"failed": 0, "successful": 10, "total": 10}, "hits": { @@ -91,8 +93,8 @@ def dummy_response(): } -@fixture -def aggs_search(): +@fixture # type: ignore +def aggs_search() -> Any: from opensearchpy import Search s = Search(index="flat-git") @@ -106,8 +108,8 @@ def aggs_search(): return s -@fixture -def aggs_data(): +@fixture # type: ignore +def aggs_data() -> Any: return { "took": 4, "timed_out": False, diff --git a/test_opensearchpy/test_helpers/test_actions.py b/test_opensearchpy/test_helpers/test_actions.py index 35b78d9a..739e8647 100644 --- a/test_opensearchpy/test_helpers/test_actions.py +++ b/test_opensearchpy/test_helpers/test_actions.py @@ -28,6 +28,7 @@ import threading import time +from typing import Any import mock import pytest @@ -40,19 +41,19 @@ lock_side_effect = threading.Lock() -def mock_process_bulk_chunk(*args, **kwargs): +def mock_process_bulk_chunk(*args: Any, **kwargs: Any) -> Any: """ Threadsafe way of mocking process bulk chunk: https://stackoverflow.com/questions/39332139/thread-safe-version-of-mock-call-count """ with lock_side_effect: - mock_process_bulk_chunk.call_count += 1 + mock_process_bulk_chunk.call_count += 1 # type: ignore time.sleep(0.1) return [] -mock_process_bulk_chunk.call_count = 0 +mock_process_bulk_chunk.call_count = 0 # type: ignore class TestParallelBulk(TestCase): @@ -60,21 +61,21 @@ class TestParallelBulk(TestCase): "opensearchpy.helpers.actions._process_bulk_chunk", side_effect=mock_process_bulk_chunk, ) - def test_all_chunks_sent(self, _process_bulk_chunk) -> None: + def test_all_chunks_sent(self, _process_bulk_chunk: Any) -> None: actions = ({"x": i} for i in range(100)) list(helpers.parallel_bulk(OpenSearch(), actions, chunk_size=2)) - self.assertEqual(50, mock_process_bulk_chunk.call_count) + self.assertEqual(50, mock_process_bulk_chunk.call_count) # type: ignore - @pytest.mark.skip + @pytest.mark.skip # type: ignore @mock.patch( "opensearchpy.helpers.actions._process_bulk_chunk", # make sure we spend some time in the thread side_effect=lambda *a: [ - (True, time.sleep(0.001) or threading.current_thread().ident) + (True, time.sleep(0.001) or threading.current_thread().ident) # type: ignore ], ) - def test_chunk_sent_from_different_threads(self, _process_bulk_chunk) -> None: + def test_chunk_sent_from_different_threads(self, _process_bulk_chunk: Any) -> None: actions = ({"x": i} for i in range(100)) results = list( helpers.parallel_bulk(OpenSearch(), actions, thread_count=10, chunk_size=2) @@ -83,8 +84,8 @@ def test_chunk_sent_from_different_threads(self, _process_bulk_chunk) -> None: class TestChunkActions(TestCase): - def setup_method(self, _) -> None: - self.actions = [({"index": {}}, {"some": u"datá", "i": i}) for i in range(100)] # fmt: skip + def setup_method(self, _: Any) -> None: + self.actions: Any = [({"index": {}}, {"some": u"datá", "i": i}) for i in range(100)] # fmt: skip def test_expand_action(self) -> None: self.assertEqual(helpers.expand_action({}), ({"index": {}}, {})) @@ -92,7 +93,7 @@ def test_expand_action(self) -> None: helpers.expand_action({"key": "val"}), ({"index": {}}, {"key": "val"}) ) - def test_expand_action_actions(self): + def test_expand_action_actions(self) -> None: self.assertEqual( helpers.expand_action( {"_op_type": "delete", "_id": "id", "_index": "index"} @@ -154,7 +155,7 @@ def test_expand_action_options(self) -> None: ({"index": {action_option: 0}}, {"key": "val"}), ) - def test__source_metadata_or_source(self): + def test__source_metadata_or_source(self) -> None: self.assertEqual( helpers.expand_action({"_source": {"key": "val"}}), ({"index": {}}, {"key": "val"}), diff --git a/test_opensearchpy/test_helpers/test_aggs.py b/test_opensearchpy/test_helpers/test_aggs.py index f46dd132..8a23c218 100644 --- a/test_opensearchpy/test_helpers/test_aggs.py +++ b/test_opensearchpy/test_helpers/test_aggs.py @@ -37,7 +37,7 @@ def test_repr() -> None: assert "Terms(aggs={'max_score': Max(field='score')}, field='tags')" == repr(a) -def test_meta(): +def test_meta() -> None: max_score = aggs.Max(field="score") a = aggs.A( "terms", field="tags", aggs={"max_score": max_score}, meta={"some": "metadata"} @@ -66,7 +66,7 @@ def test_A_creates_proper_agg() -> None: assert a._params == {"field": "tags"} -def test_A_handles_nested_aggs_properly(): +def test_A_handles_nested_aggs_properly() -> None: max_score = aggs.Max(field="score") a = aggs.A("terms", field="tags", aggs={"max_score": max_score}) @@ -79,7 +79,7 @@ def test_A_passes_aggs_through() -> None: assert aggs.A(a) is a -def test_A_from_dict(): +def test_A_from_dict() -> None: d = { "terms": {"field": "tags"}, "aggs": {"per_author": {"terms": {"field": "author.raw"}}}, @@ -95,7 +95,7 @@ def test_A_from_dict(): assert a.aggs.per_author == aggs.A("terms", field="author.raw") -def test_A_fails_with_incorrect_dict(): +def test_A_fails_with_incorrect_dict() -> None: correct_d = { "terms": {"field": "tags"}, "aggs": {"per_author": {"terms": {"field": "author.raw"}}}, @@ -148,7 +148,7 @@ def test_buckets_equals_counts_subaggs() -> None: assert a != b -def test_buckets_to_dict(): +def test_buckets_to_dict() -> None: a = aggs.Terms(field="tags") a.bucket("per_author", "terms", field="author.raw") @@ -189,7 +189,7 @@ def test_filter_can_be_instantiated_using_positional_args() -> None: assert a == aggs.A("filter", query.Q("term", f=42)) -def test_filter_aggregation_as_nested_agg(): +def test_filter_aggregation_as_nested_agg() -> None: a = aggs.Terms(field="tags") a.bucket("filtered", "filter", query.Q("term", f=42)) @@ -199,7 +199,7 @@ def test_filter_aggregation_as_nested_agg(): } == a.to_dict() -def test_filter_aggregation_with_nested_aggs(): +def test_filter_aggregation_with_nested_aggs() -> None: a = aggs.Filter(query.Q("term", f=42)) a.bucket("testing", "terms", field="tags") @@ -229,7 +229,7 @@ def test_filters_correctly_identifies_the_hash() -> None: assert a.filters.group_a == query.Q("term", group="a") -def test_bucket_sort_agg(): +def test_bucket_sort_agg() -> None: bucket_sort_agg = aggs.BucketSort(sort=[{"total_sales": {"order": "desc"}}], size=3) assert bucket_sort_agg.to_dict() == { "bucket_sort": {"sort": [{"total_sales": {"order": "desc"}}], "size": 3} @@ -254,7 +254,7 @@ def test_bucket_sort_agg(): } == a.to_dict() -def test_bucket_sort_agg_only_trnunc(): +def test_bucket_sort_agg_only_trnunc() -> None: bucket_sort_agg = aggs.BucketSort(**{"from": 1, "size": 1}) assert bucket_sort_agg.to_dict() == {"bucket_sort": {"from": 1, "size": 1}} @@ -284,7 +284,7 @@ def test_boxplot_aggregation() -> None: assert {"boxplot": {"field": "load_time"}} == a.to_dict() -def test_rare_terms_aggregation(): +def test_rare_terms_aggregation() -> None: a = aggs.RareTerms(field="the-field") a.bucket("total_sales", "sum", field="price") a.bucket( @@ -316,7 +316,7 @@ def test_median_absolute_deviation_aggregation() -> None: assert {"median_absolute_deviation": {"field": "rating"}} == a.to_dict() -def test_t_test_aggregation(): +def test_t_test_aggregation() -> None: a = aggs.TTest( a={"field": "startup_time_before"}, b={"field": "startup_time_after"}, @@ -332,14 +332,14 @@ def test_t_test_aggregation(): } == a.to_dict() -def test_inference_aggregation(): +def test_inference_aggregation() -> None: a = aggs.Inference(model_id="model-id", buckets_path={"agg_name": "agg_name"}) assert { "inference": {"buckets_path": {"agg_name": "agg_name"}, "model_id": "model-id"} } == a.to_dict() -def test_moving_percentiles_aggregation(): +def test_moving_percentiles_aggregation() -> None: a = aggs.DateHistogram() a.bucket("the_percentile", "percentiles", field="price", percents=[1.0, 99.0]) a.pipeline( diff --git a/test_opensearchpy/test_helpers/test_analysis.py b/test_opensearchpy/test_helpers/test_analysis.py index 7b8f6b04..0226ee48 100644 --- a/test_opensearchpy/test_helpers/test_analysis.py +++ b/test_opensearchpy/test_helpers/test_analysis.py @@ -36,7 +36,7 @@ def test_analyzer_serializes_as_name() -> None: assert "my_analyzer" == a.to_dict() -def test_analyzer_has_definition(): +def test_analyzer_has_definition() -> None: a = analysis.CustomAnalyzer( "my_analyzer", tokenizer="keyword", filter=["lowercase"] ) @@ -48,7 +48,7 @@ def test_analyzer_has_definition(): } == a.get_definition() -def test_simple_multiplexer_filter(): +def test_simple_multiplexer_filter() -> None: a = analysis.analyzer( "my_analyzer", tokenizer="keyword", @@ -76,7 +76,7 @@ def test_simple_multiplexer_filter(): } == a.get_analysis_definition() -def test_multiplexer_with_custom_filter(): +def test_multiplexer_with_custom_filter() -> None: a = analysis.analyzer( "my_analyzer", tokenizer="keyword", @@ -107,7 +107,7 @@ def test_multiplexer_with_custom_filter(): } == a.get_analysis_definition() -def test_conditional_token_filter(): +def test_conditional_token_filter() -> None: a = analysis.analyzer( "my_cond", tokenizer=analysis.tokenizer("keyword"), @@ -172,7 +172,7 @@ def test_normalizer_serializes_as_name() -> None: assert "my_normalizer" == n.to_dict() -def test_normalizer_has_definition(): +def test_normalizer_has_definition() -> None: n = analysis.CustomNormalizer( "my_normalizer", filter=["lowercase", "asciifolding"], char_filter=["quote"] ) @@ -191,7 +191,7 @@ def test_tokenizer() -> None: assert {"type": "nGram", "min_gram": 3, "max_gram": 3} == t.get_definition() -def test_custom_analyzer_can_collect_custom_items(): +def test_custom_analyzer_can_collect_custom_items() -> None: trigram = analysis.tokenizer("trigram", "nGram", min_gram=3, max_gram=3) my_stop = analysis.token_filter("my_stop", "stop", stopwords=["a", "b"]) umlauts = analysis.char_filter("umlauts", "pattern_replace", mappings=["ü=>ue"]) diff --git a/test_opensearchpy/test_helpers/test_document.py b/test_opensearchpy/test_helpers/test_document.py index ed78b4c0..e1b5e5c4 100644 --- a/test_opensearchpy/test_helpers/test_document.py +++ b/test_opensearchpy/test_helpers/test_document.py @@ -32,6 +32,7 @@ import pickle from datetime import datetime from hashlib import sha256 +from typing import Any from pytest import raises @@ -52,7 +53,7 @@ class MyDoc(document.Document): class MySubDoc(MyDoc): - name = field.Keyword() + name: Any = field.Keyword() class Index: name = "default-index" @@ -92,10 +93,10 @@ class Secret(str): class SecretField(field.CustomField): builtin_type = "text" - def _serialize(self, data): + def _serialize(self, data: Any) -> Any: return codecs.encode(data, "rot_13") - def _deserialize(self, data): + def _deserialize(self, data: Any) -> Any: if isinstance(data, Secret): return data return Secret(codecs.decode(data, "rot_13")) @@ -114,6 +115,8 @@ class NestedSecret(document.Document): class Index: name = "test-nested-secret" + _index: Any + class OptionalObjectWithRequiredField(document.Document): comments = field.Nested(properties={"title": field.Keyword(required=True)}) @@ -121,6 +124,8 @@ class OptionalObjectWithRequiredField(document.Document): class Index: name = "test-required" + _index: Any + class Host(document.Document): ip = field.Ip() @@ -128,12 +133,14 @@ class Host(document.Document): class Index: name = "test-host" + _index: Any + def test_range_serializes_properly() -> None: class D(document.Document): lr = field.LongRange() - d = D(lr=Range(lt=42)) + d: Any = D(lr=Range(lt=42)) assert 40 in d.lr assert 47 not in d.lr assert {"lr": {"lt": 42}} == d.to_dict() @@ -146,7 +153,7 @@ def test_range_deserializes_properly() -> None: class D(document.InnerDoc): lr = field.LongRange() - d = D.from_opensearch({"lr": {"lt": 42}}, True) + d: Any = D.from_opensearch({"lr": {"lt": 42}}, True) assert isinstance(d.lr, Range) assert 40 in d.lr assert 47 not in d.lr @@ -165,7 +172,7 @@ class A(document.Document): class B(document.Document): name = field.Keyword() - i = Index("i") + i: Any = Index("i") i.document(A) i.document(B) @@ -174,7 +181,7 @@ class B(document.Document): def test_ip_address_serializes_properly() -> None: - host = Host(ip=ipaddress.IPv4Address("10.0.0.1")) + host: Any = Host(ip=ipaddress.IPv4Address("10.0.0.1")) assert {"ip": "10.0.0.1"} == host.to_dict() @@ -202,7 +209,7 @@ class Index: def test_assigning_attrlist_to_field() -> None: - sc = SimpleCommit() + sc: Any = SimpleCommit() ls = ["README", "README.rst"] sc.files = utils.AttrList(ls) @@ -210,20 +217,20 @@ def test_assigning_attrlist_to_field() -> None: def test_optional_inner_objects_are_not_validated_if_missing() -> None: - d = OptionalObjectWithRequiredField() + d: Any = OptionalObjectWithRequiredField() assert d.full_clean() is None def test_custom_field() -> None: - s = SecretDoc(title=Secret("Hello")) + s1: Any = SecretDoc(title=Secret("Hello")) - assert {"title": "Uryyb"} == s.to_dict() - assert s.title == "Hello" + assert {"title": "Uryyb"} == s1.to_dict() + assert s1.title == "Hello" - s = SecretDoc.from_opensearch({"_source": {"title": "Uryyb"}}) - assert s.title == "Hello" - assert isinstance(s.title, Secret) + s2: Any = SecretDoc.from_opensearch({"_source": {"title": "Uryyb"}}) + assert s2.title == "Hello" + assert isinstance(s2.title, Secret) def test_custom_field_mapping() -> None: @@ -233,7 +240,7 @@ def test_custom_field_mapping() -> None: def test_custom_field_in_nested() -> None: - s = NestedSecret() + s: Any = NestedSecret() s.secrets.append(SecretDoc(title=Secret("Hello"))) assert {"secrets": [{"title": "Uryyb"}]} == s.to_dict() @@ -241,7 +248,7 @@ def test_custom_field_in_nested() -> None: def test_multi_works_after_doc_has_been_saved() -> None: - c = SimpleCommit() + c: Any = SimpleCommit() c.full_clean() c.files.append("setup.py") @@ -250,7 +257,7 @@ def test_multi_works_after_doc_has_been_saved() -> None: def test_multi_works_in_nested_after_doc_has_been_serialized() -> None: # Issue #359 - c = DocWithNested(comments=[Comment(title="First!")]) + c: Any = DocWithNested(comments=[Comment(title="First!")]) assert [] == c.comments[0].tags assert {"comments": [{"title": "First!"}]} == c.to_dict() @@ -258,17 +265,19 @@ def test_multi_works_in_nested_after_doc_has_been_serialized() -> None: def test_null_value_for_object() -> None: - d = MyDoc(inner=None) + d: Any = MyDoc(inner=None) assert d.inner is None -def test_inherited_doc_types_can_override_index(): +def test_inherited_doc_types_can_override_index() -> None: class MyDocDifferentIndex(MySubDoc): + _index: Any + class Index: name = "not-default-index" settings = {"number_of_replicas": 0} - aliases = {"a": {}} + aliases: Any = {"a": {}} analyzers = [analyzer("my_analizer", tokenizer="keyword")] assert MyDocDifferentIndex._index._name == "not-default-index" @@ -295,8 +304,8 @@ class Index: } -def test_to_dict_with_meta(): - d = MySubDoc(title="hello") +def test_to_dict_with_meta() -> None: + d: Any = MySubDoc(title="hello") d.meta.routing = "some-parent" assert { @@ -306,29 +315,29 @@ def test_to_dict_with_meta(): } == d.to_dict(True) -def test_to_dict_with_meta_includes_custom_index(): - d = MySubDoc(title="hello") +def test_to_dict_with_meta_includes_custom_index() -> None: + d: Any = MySubDoc(title="hello") d.meta.index = "other-index" assert {"_index": "other-index", "_source": {"title": "hello"}} == d.to_dict(True) def test_to_dict_without_skip_empty_will_include_empty_fields() -> None: - d = MySubDoc(tags=[], title=None, inner={}) + d: Any = MySubDoc(tags=[], title=None, inner={}) assert {} == d.to_dict() assert {"tags": [], "title": None, "inner": {}} == d.to_dict(skip_empty=False) def test_attribute_can_be_removed() -> None: - d = MyDoc(title="hello") + d: Any = MyDoc(title="hello") del d.title assert "title" not in d._d_ def test_doc_type_can_be_correctly_pickled() -> None: - d = DocWithNested( + d: Any = DocWithNested( title="Hello World!", comments=[Comment(title="hellp")], meta={"id": 42} ) s = pickle.dumps(d) @@ -343,14 +352,14 @@ def test_doc_type_can_be_correctly_pickled() -> None: def test_meta_is_accessible_even_on_empty_doc() -> None: - d = MyDoc() - d.meta + d1: Any = MyDoc() + d1.meta - d = MyDoc(title="aaa") - d.meta + d2: Any = MyDoc(title="aaa") + d2.meta -def test_meta_field_mapping(): +def test_meta_field_mapping() -> None: class User(document.Document): username = field.Text() @@ -373,7 +382,7 @@ def test_multi_value_fields() -> None: class Blog(document.Document): tags = field.Keyword(multi=True) - b = Blog() + b: Any = Blog() assert [] == b.tags b.tags.append("search") b.tags.append("python") @@ -382,20 +391,20 @@ class Blog(document.Document): def test_docs_with_properties() -> None: class User(document.Document): - pwd_hash = field.Text() + pwd_hash: Any = field.Text() - def check_password(self, pwd): + def check_password(self, pwd: Any) -> Any: return sha256(pwd).hexdigest() == self.pwd_hash @property - def password(self): + def password(self) -> Any: raise AttributeError("readonly") @password.setter - def password(self, pwd): + def password(self, pwd: Any) -> None: self.pwd_hash = sha256(pwd).hexdigest() - u = User(pwd_hash=sha256(b"secret").hexdigest()) + u: Any = User(pwd_hash=sha256(b"secret").hexdigest()) assert u.check_password(b"secret") assert not u.check_password(b"not-secret") @@ -409,8 +418,8 @@ def password(self, pwd): def test_nested_can_be_assigned_to() -> None: - d1 = DocWithNested(comments=[Comment(title="First!")]) - d2 = DocWithNested() + d1: Any = DocWithNested(comments=[Comment(title="First!")]) + d2: Any = DocWithNested() d2.comments = d1.comments assert isinstance(d1.comments[0], Comment) @@ -420,13 +429,13 @@ def test_nested_can_be_assigned_to() -> None: def test_nested_can_be_none() -> None: - d = DocWithNested(comments=None, title="Hello World!") + d: Any = DocWithNested(comments=None, title="Hello World!") assert {"title": "Hello World!"} == d.to_dict() def test_nested_defaults_to_list_and_can_be_updated() -> None: - md = DocWithNested() + md: Any = DocWithNested() assert [] == md.comments @@ -434,8 +443,8 @@ def test_nested_defaults_to_list_and_can_be_updated() -> None: assert {"comments": [{"title": "hello World!"}]} == md.to_dict() -def test_to_dict_is_recursive_and_can_cope_with_multi_values(): - md = MyDoc(name=["a", "b", "c"]) +def test_to_dict_is_recursive_and_can_cope_with_multi_values() -> None: + md: Any = MyDoc(name=["a", "b", "c"]) md.inner = [MyInner(old_field="of1"), MyInner(old_field="of2")] assert isinstance(md.inner[0], MyInner) @@ -447,12 +456,12 @@ def test_to_dict_is_recursive_and_can_cope_with_multi_values(): def test_to_dict_ignores_empty_collections() -> None: - md = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) + md: Any = MySubDoc(name="", address={}, count=0, valid=False, tags=[]) assert {"name": "", "count": 0, "valid": False} == md.to_dict() -def test_declarative_mapping_definition(): +def test_declarative_mapping_definition() -> None: assert issubclass(MyDoc, document.Document) assert hasattr(MyDoc, "_doc_type") assert { @@ -465,7 +474,7 @@ def test_declarative_mapping_definition(): } == MyDoc._doc_type.mapping.to_dict() -def test_you_can_supply_own_mapping_instance(): +def test_you_can_supply_own_mapping_instance() -> None: class MyD(document.Document): title = field.Text() @@ -479,9 +488,9 @@ class Meta: } == MyD._doc_type.mapping.to_dict() -def test_document_can_be_created_dynamically(): +def test_document_can_be_created_dynamically() -> None: n = datetime.now() - md = MyDoc(title="hello") + md: Any = MyDoc(title="hello") md.name = "My Fancy Document!" md.created_at = n @@ -501,13 +510,13 @@ def test_document_can_be_created_dynamically(): def test_invalid_date_will_raise_exception() -> None: - md = MyDoc() + md: Any = MyDoc() md.created_at = "not-a-date" with raises(ValidationException): md.full_clean() -def test_document_inheritance(): +def test_document_inheritance() -> None: assert issubclass(MySubDoc, MyDoc) assert issubclass(MySubDoc, document.Document) assert hasattr(MySubDoc, "_doc_type") @@ -521,7 +530,7 @@ def test_document_inheritance(): } == MySubDoc._doc_type.mapping.to_dict() -def test_child_class_can_override_parent(): +def test_child_class_can_override_parent() -> None: class A(document.Document): o = field.Object(dynamic=False, properties={"a": field.Text()}) @@ -540,7 +549,7 @@ class B(A): def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict() -> None: - md = MySubDoc(meta={"id": 42}, name="My First doc!") + md: Any = MySubDoc(meta={"id": 42}, name="My First doc!") md.meta.index = "my-index" assert md.meta.index == "my-index" @@ -549,7 +558,7 @@ def test_meta_fields_are_stored_in_meta_and_ignored_by_to_dict() -> None: assert {"id": 42, "index": "my-index"} == md.meta.to_dict() -def test_index_inheritance(): +def test_index_inheritance() -> None: assert issubclass(MyMultiSubDoc, MySubDoc) assert issubclass(MyMultiSubDoc, MyDoc2) assert issubclass(MyMultiSubDoc, document.Document) @@ -568,31 +577,31 @@ def test_index_inheritance(): def test_meta_fields_can_be_set_directly_in_init() -> None: p = object() - md = MyDoc(_id=p, title="Hello World!") + md: Any = MyDoc(_id=p, title="Hello World!") assert md.meta.id is p -def test_save_no_index(mock_client) -> None: - md = MyDoc() +def test_save_no_index(mock_client: Any) -> None: + md: Any = MyDoc() with raises(ValidationException): md.save(using="mock") -def test_delete_no_index(mock_client) -> None: - md = MyDoc() +def test_delete_no_index(mock_client: Any) -> None: + md: Any = MyDoc() with raises(ValidationException): md.delete(using="mock") def test_update_no_fields() -> None: - md = MyDoc() + md: Any = MyDoc() with raises(IllegalOperation): md.update() -def test_search_with_custom_alias_and_index(mock_client) -> None: - search_object = MyDoc.search( +def test_search_with_custom_alias_and_index(mock_client: Any) -> None: + search_object: Any = MyDoc.search( using="staging", index=["custom_index1", "custom_index2"] ) @@ -600,7 +609,7 @@ def test_search_with_custom_alias_and_index(mock_client) -> None: assert search_object._index == ["custom_index1", "custom_index2"] -def test_from_opensearch_respects_underscored_non_meta_fields(): +def test_from_opensearch_respects_underscored_non_meta_fields() -> None: doc = { "_index": "test-index", "_id": "opensearch", @@ -617,18 +626,18 @@ class Company(document.Document): class Index: name = "test-company" - c = Company.from_opensearch(doc) + c: Any = Company.from_opensearch(doc) assert c.meta.fields._tags == ["search"] assert c.meta.fields._routing == "opensearch" assert c._tagline == "You know, for search" -def test_nested_and_object_inner_doc(): +def test_nested_and_object_inner_doc() -> None: class MySubDocWithNested(MyDoc): nested_inner = field.Nested(MyInner) - props = MySubDocWithNested._doc_type.mapping.to_dict()["properties"] + props: Any = MySubDocWithNested._doc_type.mapping.to_dict()["properties"] assert props == { "created_at": {"type": "date"}, "inner": {"properties": {"old_field": {"type": "text"}}, "type": "object"}, diff --git a/test_opensearchpy/test_helpers/test_faceted_search.py b/test_opensearchpy/test_helpers/test_faceted_search.py index e663bca1..93716ce1 100644 --- a/test_opensearchpy/test_helpers/test_faceted_search.py +++ b/test_opensearchpy/test_helpers/test_faceted_search.py @@ -26,6 +26,7 @@ # under the License. from datetime import datetime +from typing import Any import pytest @@ -72,7 +73,7 @@ def test_query_is_created_properly() -> None: } == s.to_dict() -def test_query_is_created_properly_with_sort_tuple(): +def test_query_is_created_properly_with_sort_tuple() -> None: bs = BlogSearch("python search", sort=("category", "-title")) s = bs.build_search() @@ -96,7 +97,7 @@ def test_query_is_created_properly_with_sort_tuple(): } == s.to_dict() -def test_filter_is_applied_to_search_but_not_relevant_facet(): +def test_filter_is_applied_to_search_but_not_relevant_facet() -> None: bs = BlogSearch("python search", filters={"category": "opensearch"}) s = bs.build_search() @@ -119,7 +120,7 @@ def test_filter_is_applied_to_search_but_not_relevant_facet(): } == s.to_dict() -def test_filters_are_applied_to_search_ant_relevant_facets(): +def test_filters_are_applied_to_search_ant_relevant_facets() -> None: bs = BlogSearch( "python search", filters={"category": "opensearch", "tags": ["python", "django"]}, @@ -185,8 +186,8 @@ def test_date_histogram_facet_with_1970_01_01_date() -> None: ("interval", "1h"), ("fixed_interval", "1h"), ], -) -def test_date_histogram_interval_types(interval_type, interval) -> None: +) # type: ignore +def test_date_histogram_interval_types(interval_type: Any, interval: Any) -> None: dhf = DateHistogramFacet(field="@timestamp", **{interval_type: interval}) assert dhf.get_aggregation().to_dict() == { "date_histogram": { diff --git a/test_opensearchpy/test_helpers/test_field.py b/test_opensearchpy/test_helpers/test_field.py index 19582730..ce818b50 100644 --- a/test_opensearchpy/test_helpers/test_field.py +++ b/test_opensearchpy/test_helpers/test_field.py @@ -28,6 +28,7 @@ import base64 from datetime import datetime from ipaddress import ip_address +from typing import Any import pytest from dateutil import tz @@ -59,7 +60,7 @@ def test_boolean_deserialization() -> None: def test_date_field_can_have_default_tz() -> None: - f = field.Date(default_timezone="UTC") + f: Any = field.Date(default_timezone="UTC") now = datetime.now() now_with_tz = f._deserialize(now) @@ -76,7 +77,7 @@ def test_date_field_can_have_default_tz() -> None: def test_custom_field_car_wrap_other_field() -> None: class MyField(field.CustomField): @property - def builtin_type(self): + def builtin_type(self) -> Any: return field.Text(**self._params) assert {"type": "text", "index": "not_analyzed"} == MyField( @@ -91,7 +92,7 @@ def test_field_from_dict() -> None: assert {"type": "text", "index": "not_analyzed"} == f.to_dict() -def test_multi_fields_are_accepted_and_parsed(): +def test_multi_fields_are_accepted_and_parsed() -> None: f = field.construct_field( "text", fields={"raw": {"type": "keyword"}, "eng": field.Text(analyzer="english")}, @@ -123,7 +124,7 @@ def test_field_supports_multiple_analyzers() -> None: } == f.to_dict() -def test_multifield_supports_multiple_analyzers(): +def test_multifield_supports_multiple_analyzers() -> None: f = field.Text( fields={ "f1": field.Text(search_analyzer="keyword", analyzer="snowball"), @@ -145,8 +146,8 @@ def test_multifield_supports_multiple_analyzers(): def test_scaled_float() -> None: with pytest.raises(TypeError): - field.ScaledFloat() - f = field.ScaledFloat(123) + field.ScaledFloat() # type: ignore + f: Any = field.ScaledFloat(scaling_factor=123) assert f.to_dict() == {"scaling_factor": 123, "type": "scaled_float"} @@ -204,7 +205,7 @@ def test_object_disabled() -> None: assert f.to_dict() == {"type": "object", "enabled": False} -def test_object_constructor(): +def test_object_constructor() -> None: expected = {"type": "object", "properties": {"inner_int": {"type": "integer"}}} class Inner(InnerDoc): diff --git a/test_opensearchpy/test_helpers/test_index.py b/test_opensearchpy/test_helpers/test_index.py index bb8aa578..59c3e28e 100644 --- a/test_opensearchpy/test_helpers/test_index.py +++ b/test_opensearchpy/test_helpers/test_index.py @@ -27,6 +27,7 @@ import string from random import choice +from typing import Any from pytest import raises @@ -65,7 +66,7 @@ def test_search_is_limited_to_index_name() -> None: def test_cloned_index_has_copied_settings_and_using() -> None: client = object() - i = Index("my-index", using=client) + i: Any = Index("my-index", using=client) i.settings(number_of_shards=1) i2 = i.clone("my-other-index") @@ -82,7 +83,7 @@ def test_cloned_index_has_analysis_attribute() -> None: over the `_analysis` attribute. """ client = object() - i = Index("my-index", using=client) + i: Any = Index("my-index", using=client) random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer = analyzer( @@ -97,7 +98,7 @@ def test_cloned_index_has_analysis_attribute() -> None: def test_settings_are_saved() -> None: - i = Index("i") + i: Any = Index("i") i.settings(number_of_replicas=0) i.settings(number_of_shards=1) @@ -105,7 +106,7 @@ def test_settings_are_saved() -> None: def test_registered_doc_type_included_in_to_dict() -> None: - i = Index("i", using="alias") + i: Any = Index("i", using="alias") i.document(Post) assert { @@ -119,7 +120,7 @@ def test_registered_doc_type_included_in_to_dict() -> None: def test_registered_doc_type_included_in_search() -> None: - i = Index("i", using="alias") + i: Any = Index("i", using="alias") i.document(Post) s = i.search() @@ -129,9 +130,9 @@ def test_registered_doc_type_included_in_search() -> None: def test_aliases_add_to_object() -> None: random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) - alias_dict = {random_alias: {}} + alias_dict: Any = {random_alias: {}} - index = Index("i", using="alias") + index: Any = Index("i", using="alias") index.aliases(**alias_dict) assert index._aliases == alias_dict @@ -139,21 +140,21 @@ def test_aliases_add_to_object() -> None: def test_aliases_returned_from_to_dict() -> None: random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) - alias_dict = {random_alias: {}} + alias_dict: Any = {random_alias: {}} - index = Index("i", using="alias") + index: Any = Index("i", using="alias") index.aliases(**alias_dict) assert index._aliases == index.to_dict()["aliases"] == alias_dict -def test_analyzers_added_to_object(): +def test_analyzers_added_to_object() -> None: random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" ) - index = Index("i", using="alias") + index: Any = Index("i", using="alias") index.analyzer(random_analyzer) assert index._analysis["analyzer"][random_analyzer_name] == { @@ -163,12 +164,12 @@ def test_analyzers_added_to_object(): } -def test_analyzers_returned_from_to_dict(): +def test_analyzers_returned_from_to_dict() -> None: random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer = analyzer( random_analyzer_name, tokenizer="standard", filter="standard" ) - index = Index("i", using="alias") + index: Any = Index("i", using="alias") index.analyzer(random_analyzer) assert index.to_dict()["settings"]["analysis"]["analyzer"][ @@ -177,21 +178,21 @@ def test_analyzers_returned_from_to_dict(): def test_conflicting_analyzer_raises_error() -> None: - i = Index("i") + i: Any = Index("i") i.analyzer("my_analyzer", tokenizer="whitespace", filter=["lowercase", "stop"]) with raises(ValueError): i.analyzer("my_analyzer", tokenizer="keyword", filter=["lowercase", "stop"]) -def test_index_template_can_have_order(): - i = Index("i-*") +def test_index_template_can_have_order() -> None: + i: Any = Index("i-*") it = i.as_template("i", order=2) assert {"index_patterns": ["i-*"], "order": 2} == it.to_dict() -def test_index_template_save_result(mock_client) -> None: - it = IndexTemplate("test-template", "test-*") +def test_index_template_save_result(mock_client: Any) -> None: + it: Any = IndexTemplate("test-template", "test-*") assert it.save(using="mock") == mock_client.indices.put_template() diff --git a/test_opensearchpy/test_helpers/test_mapping.py b/test_opensearchpy/test_helpers/test_mapping.py index 5e4e49ce..2006b66f 100644 --- a/test_opensearchpy/test_helpers/test_mapping.py +++ b/test_opensearchpy/test_helpers/test_mapping.py @@ -40,7 +40,7 @@ def test_mapping_can_has_fields() -> None: } == m.to_dict() -def test_mapping_update_is_recursive(): +def test_mapping_update_is_recursive() -> None: m1 = mapping.Mapping() m1.field("title", "text") m1.field("author", "object") @@ -83,7 +83,7 @@ def test_properties_can_iterate_over_all_the_fields() -> None: } -def test_mapping_can_collect_all_analyzers_and_normalizers(): +def test_mapping_can_collect_all_analyzers_and_normalizers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", @@ -156,7 +156,7 @@ def test_mapping_can_collect_all_analyzers_and_normalizers(): assert json.loads(json.dumps(m.to_dict())) == m.to_dict() -def test_mapping_can_collect_multiple_analyzers(): +def test_mapping_can_collect_multiple_analyzers() -> None: a1 = analysis.analyzer( "my_analyzer1", tokenizer="keyword", diff --git a/test_opensearchpy/test_helpers/test_query.py b/test_opensearchpy/test_helpers/test_query.py index 142b865c..27790748 100644 --- a/test_opensearchpy/test_helpers/test_query.py +++ b/test_opensearchpy/test_helpers/test_query.py @@ -25,6 +25,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + from pytest import raises from opensearchpy.helpers import function, query @@ -122,8 +124,8 @@ def test_other_and_bool_appends_other_to_must() -> None: def test_bool_and_other_appends_other_to_must() -> None: - q1 = query.Match(f="value1") - qb = query.Bool() + q1: Any = query.Match(f="value1") + qb: Any = query.Bool() q = qb & q1 assert q is not qb @@ -463,7 +465,7 @@ def test_function_score_with_functions() -> None: } == q.to_dict() -def test_function_score_with_no_function_is_boost_factor(): +def test_function_score_with_no_function_is_boost_factor() -> None: q = query.Q( "function_score", functions=[query.SF({"weight": 20, "filter": query.Q("term", f=42)})], @@ -474,7 +476,7 @@ def test_function_score_with_no_function_is_boost_factor(): } == q.to_dict() -def test_function_score_to_dict(): +def test_function_score_to_dict() -> None: q = query.Q( "function_score", query=query.Q("match", title="python"), @@ -503,7 +505,7 @@ def test_function_score_to_dict(): assert d == q.to_dict() -def test_function_score_with_single_function(): +def test_function_score_with_single_function() -> None: d = { "function_score": { "filter": {"term": {"tags": "python"}}, @@ -521,7 +523,7 @@ def test_function_score_with_single_function(): assert "doc['comment_count'] * _score" == sf.script -def test_function_score_from_dict(): +def test_function_score_from_dict() -> None: d = { "function_score": { "filter": {"term": {"tags": "python"}}, diff --git a/test_opensearchpy/test_helpers/test_result.py b/test_opensearchpy/test_helpers/test_result.py index 657beb05..296553f3 100644 --- a/test_opensearchpy/test_helpers/test_result.py +++ b/test_opensearchpy/test_helpers/test_result.py @@ -27,6 +27,7 @@ import pickle from datetime import date +from typing import Any from pytest import fixture, raises @@ -36,12 +37,12 @@ from opensearchpy.helpers.response.aggs import AggResponse, Bucket, BucketData -@fixture -def agg_response(aggs_search, aggs_data): +@fixture # type: ignore +def agg_response(aggs_search: Any, aggs_data: Any) -> Any: return response.Response(aggs_search, aggs_data) -def test_agg_response_is_pickleable(agg_response) -> None: +def test_agg_response_is_pickleable(agg_response: Any) -> None: agg_response.hits r = pickle.loads(pickle.dumps(agg_response)) @@ -50,7 +51,7 @@ def test_agg_response_is_pickleable(agg_response) -> None: assert r.hits == agg_response.hits -def test_response_is_pickleable(dummy_response) -> None: +def test_response_is_pickleable(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) res.hits r = pickle.loads(pickle.dumps(res)) @@ -60,7 +61,7 @@ def test_response_is_pickleable(dummy_response) -> None: assert r.hits == res.hits -def test_hit_is_pickleable(dummy_response) -> None: +def test_hit_is_pickleable(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) hits = pickle.loads(pickle.dumps(res.hits)) @@ -68,14 +69,14 @@ def test_hit_is_pickleable(dummy_response) -> None: assert hits[0].meta == res.hits[0].meta -def test_response_stores_search(dummy_response) -> None: +def test_response_stores_search(dummy_response: Any) -> None: s = Search() r = response.Response(s, dummy_response) assert r._search is s -def test_interactive_helpers(dummy_response) -> None: +def test_interactive_helpers(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) hits = res.hits h = hits[0] @@ -98,19 +99,19 @@ def test_interactive_helpers(dummy_response) -> None: ] == repr(h) -def test_empty_response_is_false(dummy_response) -> None: +def test_empty_response_is_false(dummy_response: Any) -> None: dummy_response["hits"]["hits"] = [] res = response.Response(Search(), dummy_response) assert not res -def test_len_response(dummy_response) -> None: +def test_len_response(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) assert len(res) == 4 -def test_iterating_over_response_gives_you_hits(dummy_response) -> None: +def test_iterating_over_response_gives_you_hits(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) hits = list(h for h in res) @@ -127,7 +128,7 @@ def test_iterating_over_response_gives_you_hits(dummy_response) -> None: assert hits[1].meta.routing == "opensearch" -def test_hits_get_wrapped_to_contain_additional_attrs(dummy_response) -> None: +def test_hits_get_wrapped_to_contain_additional_attrs(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) hits = res.hits @@ -135,7 +136,7 @@ def test_hits_get_wrapped_to_contain_additional_attrs(dummy_response) -> None: assert 12.0 == hits.max_score -def test_hits_provide_dot_and_bracket_access_to_attrs(dummy_response) -> None: +def test_hits_provide_dot_and_bracket_access_to_attrs(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) h = res.hits[0] @@ -151,30 +152,32 @@ def test_hits_provide_dot_and_bracket_access_to_attrs(dummy_response) -> None: h.not_there -def test_slicing_on_response_slices_on_hits(dummy_response) -> None: +def test_slicing_on_response_slices_on_hits(dummy_response: Any) -> None: res = response.Response(Search(), dummy_response) assert res[0] is res.hits[0] assert res[::-1] == res.hits[::-1] -def test_aggregation_base(agg_response) -> None: +def test_aggregation_base(agg_response: Any) -> None: assert agg_response.aggs is agg_response.aggregations assert isinstance(agg_response.aggs, response.AggResponse) -def test_metric_agg_works(agg_response) -> None: +def test_metric_agg_works(agg_response: Any) -> None: assert 25052.0 == agg_response.aggs.sum_lines.value -def test_aggregations_can_be_iterated_over(agg_response) -> None: +def test_aggregations_can_be_iterated_over(agg_response: Any) -> None: aggs = [a for a in agg_response.aggs] assert len(aggs) == 3 assert all(map(lambda a: isinstance(a, AggResponse), aggs)) -def test_aggregations_can_be_retrieved_by_name(agg_response, aggs_search) -> None: +def test_aggregations_can_be_retrieved_by_name( + agg_response: Any, aggs_search: Any +) -> None: a = agg_response.aggs["popular_files"] assert isinstance(a, BucketData) @@ -182,7 +185,7 @@ def test_aggregations_can_be_retrieved_by_name(agg_response, aggs_search) -> Non assert a._meta["aggs"] is aggs_search.aggs.aggs["popular_files"] -def test_bucket_response_can_be_iterated_over(agg_response) -> None: +def test_bucket_response_can_be_iterated_over(agg_response: Any) -> None: popular_files = agg_response.aggregations.popular_files buckets = [b for b in popular_files] @@ -190,7 +193,7 @@ def test_bucket_response_can_be_iterated_over(agg_response) -> None: assert buckets == popular_files.buckets -def test_bucket_keys_get_deserialized(aggs_data, aggs_search) -> None: +def test_bucket_keys_get_deserialized(aggs_data: Any, aggs_search: Any) -> None: class Commit(Document): info = Object(properties={"committed_date": Date()}) diff --git a/test_opensearchpy/test_helpers/test_search.py b/test_opensearchpy/test_helpers/test_search.py index 73d078a9..b44d5dd5 100644 --- a/test_opensearchpy/test_helpers/test_search.py +++ b/test_opensearchpy/test_helpers/test_search.py @@ -26,6 +26,7 @@ # under the License. from copy import deepcopy +from typing import Any from pytest import raises @@ -41,16 +42,16 @@ def test_expand__to_dot_is_respected() -> None: def test_execute_uses_cache() -> None: - s = search.Search() - r = object() + s: Any = search.Search() + r: Any = object() s._response = r assert r is s.execute() -def test_cache_can_be_ignored(mock_client) -> None: - s = search.Search(using="mock") - r = object() +def test_cache_can_be_ignored(mock_client: Any) -> None: + s: Any = search.Search(using="mock") + r: Any = object() s._response = r s.execute(ignore_cache=True) @@ -58,27 +59,27 @@ def test_cache_can_be_ignored(mock_client) -> None: def test_iter_iterates_over_hits() -> None: - s = search.Search() + s: Any = search.Search() s._response = [1, 2, 3] assert [1, 2, 3] == list(s) def test_cache_isnt_cloned() -> None: - s = search.Search() + s: Any = search.Search() s._response = object() assert not hasattr(s._clone(), "_response") def test_search_starts_with_no_query() -> None: - s = search.Search() + s: Any = search.Search() assert s.query._proxied is None def test_search_query_combines_query() -> None: - s = search.Search() + s: Any = search.Search() s2 = s.query("match", f=42) assert s2.query._proxied == query.Match(f=42) @@ -90,7 +91,7 @@ def test_search_query_combines_query() -> None: def test_query_can_be_assigned_to() -> None: - s = search.Search() + s: Any = search.Search() q = Q("match", title="python") s.query = q @@ -98,8 +99,8 @@ def test_query_can_be_assigned_to() -> None: assert s.query._proxied is q -def test_query_can_be_wrapped(): - s = search.Search().query("match", title="python") +def test_query_can_be_wrapped() -> None: + s: Any = search.Search().query("match", title="python") s.query = Q("function_score", query=s.query, field_value_factor={"field": "rating"}) @@ -114,9 +115,9 @@ def test_query_can_be_wrapped(): def test_using() -> None: - o = object() - o2 = object() - s = search.Search(using=o) + o: Any = object() + o2: Any = object() + s: Any = search.Search(using=o) assert s._using is o s2 = s.using(o2) assert s._using is o @@ -124,27 +125,27 @@ def test_using() -> None: def test_methods_are_proxied_to_the_query() -> None: - s = search.Search().query("match_all") + s: Any = search.Search().query("match_all") assert s.query.to_dict() == {"match_all": {}} def test_query_always_returns_search() -> None: - s = search.Search() + s: Any = search.Search() assert isinstance(s.query("match", f=42), search.Search) def test_source_copied_on_clone() -> None: - s = search.Search().source(False) + s: Any = search.Search().source(False) assert s._clone()._source == s._source assert s._clone()._source is False - s2 = search.Search().source([]) + s2: Any = search.Search().source([]) assert s2._clone()._source == s2._source assert s2._source == [] - s3 = search.Search().source(["some", "fields"]) + s3: Any = search.Search().source(["some", "fields"]) assert s3._clone()._source == s3._source assert s3._clone()._source == ["some", "fields"] @@ -152,15 +153,15 @@ def test_source_copied_on_clone() -> None: def test_copy_clones() -> None: from copy import copy - s1 = search.Search().source(["some", "fields"]) - s2 = copy(s1) + s1: Any = search.Search().source(["some", "fields"]) + s2: Any = copy(s1) assert s1 == s2 assert s1 is not s2 def test_aggs_allow_two_metric() -> None: - s = search.Search() + s: Any = search.Search() s.aggs.metric("a", "max", field="a").metric("b", "max", field="b") @@ -169,8 +170,8 @@ def test_aggs_allow_two_metric() -> None: } -def test_aggs_get_copied_on_change(): - s = search.Search().query("match_all") +def test_aggs_get_copied_on_change() -> None: + s: Any = search.Search().query("match_all") s.aggs.bucket("per_tag", "terms", field="f").metric( "max_score", "max", field="score" ) @@ -182,7 +183,7 @@ def test_aggs_get_copied_on_change(): s4 = s3._clone() s4.aggs.metric("max_score", "max", field="score") - d = { + d: Any = { "query": {"match_all": {}}, "aggs": { "per_tag": { @@ -245,7 +246,7 @@ class MyDocument(Document): assert s._doc_type_map == {} -def test_sort(): +def test_sort() -> None: s = search.Search() s = s.sort("fielda", "-fieldb") @@ -267,7 +268,7 @@ def test_sort_by_score() -> None: s.sort("-_score") -def test_collapse(): +def test_collapse() -> None: s = search.Search() inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]} @@ -315,7 +316,7 @@ def test_index() -> None: assert {"from": 3, "size": 1} == s[3].to_dict() -def test_search_to_dict(): +def test_search_to_dict() -> None: s = search.Search() assert {} == s.to_dict() @@ -344,7 +345,7 @@ def test_search_to_dict(): assert {"size": 5, "from": 42} == s.to_dict() -def test_complex_example(): +def test_complex_example() -> None: s = search.Search() s = ( s.query("match", title="python") @@ -395,7 +396,7 @@ def test_complex_example(): } == s.to_dict() -def test_reverse(): +def test_reverse() -> None: d = { "query": { "filtered": { @@ -451,7 +452,7 @@ def test_from_dict_doesnt_need_query() -> None: assert {"size": 5} == s.to_dict() -def test_params_being_passed_to_search(mock_client) -> None: +def test_params_being_passed_to_search(mock_client: Any) -> None: s = search.Search(using="mock") s = s.params(routing="42") s.execute() @@ -473,7 +474,7 @@ def test_source() -> None: ).source(["f1", "f2"]).to_dict() -def test_source_on_clone(): +def test_source_on_clone() -> None: assert { "_source": {"includes": ["foo.bar.*"], "excludes": ["foo.one"]}, "query": {"bool": {"filter": [{"term": {"title": "python"}}]}}, @@ -498,7 +499,7 @@ def test_source_on_clear() -> None: ) -def test_suggest_accepts_global_text(): +def test_suggest_accepts_global_text() -> None: s = search.Search.from_dict( { "suggest": { @@ -520,7 +521,7 @@ def test_suggest_accepts_global_text(): } == s.to_dict() -def test_suggest(): +def test_suggest() -> None: s = search.Search() s = s.suggest("my_suggestion", "pyhton", term={"field": "title"}) @@ -542,7 +543,7 @@ def test_exclude() -> None: } == s.to_dict() -def test_delete_by_query(mock_client) -> None: +def test_delete_by_query(mock_client: Any) -> None: s = search.Search(using="mock").query("match", lang="java") s.delete() @@ -551,7 +552,7 @@ def test_delete_by_query(mock_client) -> None: ) -def test_update_from_dict(): +def test_update_from_dict() -> None: s = search.Search() s.update_from_dict({"indices_boost": [{"important-documents": 2}]}) s.update_from_dict({"_source": ["id", "name"]}) @@ -562,7 +563,7 @@ def test_update_from_dict(): } == s.to_dict() -def test_rescore_query_to_dict(): +def test_rescore_query_to_dict() -> None: s = search.Search(index="index-name") positive_query = Q( diff --git a/test_opensearchpy/test_helpers/test_update_by_query.py b/test_opensearchpy/test_helpers/test_update_by_query.py index 74030874..90e7aa78 100644 --- a/test_opensearchpy/test_helpers/test_update_by_query.py +++ b/test_opensearchpy/test_helpers/test_update_by_query.py @@ -26,6 +26,7 @@ # under the License. from copy import deepcopy +from typing import Any from opensearchpy import Q, UpdateByQuery from opensearchpy.helpers.response import UpdateByQueryResponse @@ -37,7 +38,7 @@ def test_ubq_starts_with_no_query() -> None: assert ubq.query._proxied is None -def test_ubq_to_dict(): +def test_ubq_to_dict() -> None: ubq = UpdateByQuery() assert {} == ubq.to_dict() @@ -53,7 +54,7 @@ def test_ubq_to_dict(): assert {"extra_q": {"term": {"category": "conference"}}} == ubq.to_dict() -def test_complex_example(): +def test_complex_example() -> None: ubq = UpdateByQuery() ubq = ( ubq.query("match", title="python") @@ -104,7 +105,7 @@ def test_exclude() -> None: } == ubq.to_dict() -def test_reverse(): +def test_reverse() -> None: d = { "query": { "filtered": { @@ -146,7 +147,7 @@ def test_from_dict_doesnt_need_query() -> None: assert {"script": {"source": "test"}} == ubq.to_dict() -def test_params_being_passed_to_search(mock_client) -> None: +def test_params_being_passed_to_search(mock_client: Any) -> None: ubq = UpdateByQuery(using="mock") ubq = ubq.params(routing="42") ubq.execute() @@ -156,7 +157,7 @@ def test_params_being_passed_to_search(mock_client) -> None: ) -def test_overwrite_script(): +def test_overwrite_script() -> None: ubq = UpdateByQuery() ubq = ubq.script( source="ctx._source.likes += params.f", lang="painless", params={"f": 3} diff --git a/test_opensearchpy/test_helpers/test_utils.py b/test_opensearchpy/test_helpers/test_utils.py index 358b9184..b6949833 100644 --- a/test_opensearchpy/test_helpers/test_utils.py +++ b/test_opensearchpy/test_helpers/test_utils.py @@ -55,7 +55,7 @@ class MyAttrDict(utils.AttrDict): assert isinstance(ls[:][0], MyAttrDict) -def test_merge(): +def test_merge() -> None: a = utils.AttrDict({"a": {"b": 42, "c": 47}}) b = {"a": {"b": 123, "d": -12}, "e": [1, 2, 3]} @@ -101,7 +101,7 @@ def test_serializer_deals_with_Attr_versions() -> None: def test_serializer_deals_with_objects_with_to_dict() -> None: class MyClass(object): - def to_dict(self): + def to_dict(self) -> int: return 42 assert serializer.serializer.dumps(MyClass()) == "42" diff --git a/test_opensearchpy/test_helpers/test_validation.py b/test_opensearchpy/test_helpers/test_validation.py index 1565b352..6841f604 100644 --- a/test_opensearchpy/test_helpers/test_validation.py +++ b/test_opensearchpy/test_helpers/test_validation.py @@ -26,6 +26,7 @@ # under the License. from datetime import datetime +from typing import Any from pytest import raises @@ -43,8 +44,8 @@ class Author(InnerDoc): - name = Text(required=True) - email = Text(required=True) + name: Any = Text(required=True) + email: Any = Text(required=True) def clean(self) -> None: print(self, type(self), self.name) @@ -63,7 +64,7 @@ class BlogPostWithStatus(Document): class AutoNowDate(Date): - def clean(self, data): + def clean(self, data: Any) -> Any: if data is None: data = datetime.now() return super(AutoNowDate, self).clean(data) @@ -78,7 +79,7 @@ def test_required_int_can_be_0() -> None: class DT(Document): i = Integer(required=True) - dt = DT(i=0) + dt: Any = DT(i=0) assert dt.full_clean() is None @@ -95,12 +96,12 @@ def test_validation_works_for_lists_of_values() -> None: class DT(Document): i = Date(required=True) - dt = DT(i=[datetime.now(), "not date"]) + dt1: Any = DT(i=[datetime.now(), "not date"]) with raises(ValidationException): - dt.full_clean() + dt1.full_clean() - dt = DT(i=[datetime.now(), datetime.now()]) - assert None is dt.full_clean() + dt2: Any = DT(i=[datetime.now(), datetime.now()]) + assert None is dt2.full_clean() def test_field_with_custom_clean() -> None: @@ -111,29 +112,29 @@ def test_field_with_custom_clean() -> None: def test_empty_object() -> None: - d = BlogPost(authors=[{"name": "Guian", "email": "guiang@bitquilltech.com"}]) + d: Any = BlogPost(authors=[{"name": "Guian", "email": "guiang@bitquilltech.com"}]) d.inner = {} d.full_clean() def test_missing_required_field_raises_validation_exception() -> None: - d = BlogPost() + d1: Any = BlogPost() with raises(ValidationException): - d.full_clean() + d1.full_clean() - d = BlogPost() - d.authors.append({"name": "Guian"}) + d2: Any = BlogPost() + d2.authors.append({"name": "Guian"}) with raises(ValidationException): - d.full_clean() + d2.full_clean() - d = BlogPost() - d.authors.append({"name": "Guian", "email": "guiang@bitquilltech.com"}) - d.full_clean() + d3: Any = BlogPost() + d3.authors.append({"name": "Guian", "email": "guiang@bitquilltech.com"}) + d3.full_clean() def test_boolean_doesnt_treat_false_as_empty() -> None: - d = BlogPostWithStatus() + d: Any = BlogPostWithStatus() with raises(ValidationException): d.full_clean() d.published = False @@ -143,7 +144,9 @@ def test_boolean_doesnt_treat_false_as_empty() -> None: def test_custom_validation_on_nested_gets_run() -> None: - d = BlogPost(authors=[Author(name="Guian", email="king@example.com")], created=None) + d: Any = BlogPost( + authors=[Author(name="Guian", email="king@example.com")], created=None + ) assert isinstance(d.authors[0], Author) @@ -152,7 +155,7 @@ def test_custom_validation_on_nested_gets_run() -> None: def test_accessing_known_fields_returns_empty_value() -> None: - d = BlogPost() + d: Any = BlogPost() assert [] == d.authors @@ -162,7 +165,7 @@ def test_accessing_known_fields_returns_empty_value() -> None: def test_empty_values_are_not_serialized() -> None: - d = BlogPost( + d: Any = BlogPost( authors=[{"name": "Guian", "email": "guiang@bitquilltech.com"}], created=None ) diff --git a/test_opensearchpy/test_helpers/test_wrappers.py b/test_opensearchpy/test_helpers/test_wrappers.py index 2212b070..2f9bacba 100644 --- a/test_opensearchpy/test_helpers/test_wrappers.py +++ b/test_opensearchpy/test_helpers/test_wrappers.py @@ -26,6 +26,7 @@ # under the License. from datetime import datetime, timedelta +from typing import Any import pytest @@ -43,8 +44,8 @@ ({"lte": 4, "gte": 2}, 2), ({"gt": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], -) -def test_range_contains(kwargs, item) -> None: +) # type: ignore +def test_range_contains(kwargs: Any, item: Any) -> None: assert item in Range(**kwargs) @@ -57,8 +58,8 @@ def test_range_contains(kwargs, item) -> None: ({"lte": 4, "gte": 2}, 1), ({"lte": datetime.now() - timedelta(seconds=10)}, datetime.now()), ], -) -def test_range_not_contains(kwargs, item): +) # type: ignore +def test_range_not_contains(kwargs: Any, item: Any) -> None: assert item not in Range(**kwargs) @@ -71,8 +72,8 @@ def test_range_not_contains(kwargs, item): ((), {"lt": 1, "lte": 1}), ((), {"gt": 1, "gte": 1}), ], -) -def test_range_raises_value_error_on_wrong_params(args, kwargs) -> None: +) # type: ignore +def test_range_raises_value_error_on_wrong_params(args: Any, kwargs: Any) -> None: with pytest.raises(ValueError): Range(*args, **kwargs) @@ -85,8 +86,8 @@ def test_range_raises_value_error_on_wrong_params(args, kwargs) -> None: (Range(), None, False), (Range(lt=42), None, False), ], -) -def test_range_lower(range, lower, inclusive) -> None: +) # type: ignore +def test_range_lower(range: Any, lower: Any, inclusive: Any) -> None: assert (lower, inclusive) == range.lower @@ -98,6 +99,6 @@ def test_range_lower(range, lower, inclusive) -> None: (Range(), None, False), (Range(gt=42), None, False), ], -) -def test_range_upper(range, upper, inclusive) -> None: +) # type: ignore +def test_range_upper(range: Any, upper: Any, inclusive: Any) -> None: assert (upper, inclusive) == range.upper diff --git a/test_opensearchpy/test_serializer.py b/test_opensearchpy/test_serializer.py index d7fef3e8..d425fabf 100644 --- a/test_opensearchpy/test_serializer.py +++ b/test_opensearchpy/test_serializer.py @@ -30,6 +30,7 @@ import uuid from datetime import datetime from decimal import Decimal +from typing import Any try: import numpy as np @@ -212,7 +213,7 @@ def test_raises_serialization_error_on_dump_error(self) -> None: class TestDeserializer(TestCase): - def setup_method(self, _) -> None: + def setup_method(self, _: Any) -> None: self.de = Deserializer(DEFAULT_SERIALIZERS) def test_deserializes_json_by_default(self) -> None: diff --git a/test_opensearchpy/test_server/__init__.py b/test_opensearchpy/test_server/__init__.py index d3965fed..2da5aa90 100644 --- a/test_opensearchpy/test_server/__init__.py +++ b/test_opensearchpy/test_server/__init__.py @@ -26,15 +26,16 @@ # under the License. +from typing import Any from unittest import SkipTest from opensearchpy.helpers import test -from opensearchpy.helpers.test import OpenSearchTestCase as BaseTestCase +from test_opensearchpy.test_cases import OpenSearchTestCase as BaseTestCase client = None -def get_client(**kwargs): +def get_client(**kwargs: Any) -> Any: global client if client is False: raise SkipTest("No client is available") @@ -49,7 +50,7 @@ def get_client(**kwargs): except ImportError: # fallback to using vanilla client try: - new_client = test.get_test_client(**kwargs) + new_client = test.get_test_client(**kwargs) # type: ignore except SkipTest: client = False raise @@ -66,5 +67,5 @@ def setup_module() -> None: class OpenSearchTestCase(BaseTestCase): @staticmethod - def _get_client(**kwargs): + def _get_client(**kwargs: Any) -> Any: return get_client(**kwargs) diff --git a/test_opensearchpy/test_server/conftest.py b/test_opensearchpy/test_server/conftest.py index 128c33eb..8459092c 100644 --- a/test_opensearchpy/test_server/conftest.py +++ b/test_opensearchpy/test_server/conftest.py @@ -28,11 +28,12 @@ import os import time +from typing import Any import pytest import opensearchpy -from opensearchpy.helpers.test import OPENSEARCH_URL +from opensearchpy.helpers.test import OPENSEARCH_URL # type: ignore from ..utils import wipe_cluster @@ -40,11 +41,11 @@ # Used for OPENSEARCH_VERSION = "" OPENSEARCH_BUILD_HASH = "" -OPENSEARCH_REST_API_TESTS = [] +OPENSEARCH_REST_API_TESTS: Any = [] -@pytest.fixture(scope="session") -def sync_client_factory(): +@pytest.fixture(scope="session") # type: ignore +def sync_client_factory() -> Any: client = None try: # Configure the client optionally with an HTTP conn class @@ -63,7 +64,7 @@ def sync_client_factory(): # We do this little dance with the URL to force # Requests to respect 'headers: None' within rest API spec tests. client = opensearchpy.OpenSearch( - OPENSEARCH_URL.replace("elastic:changeme@", ""), **kw + OPENSEARCH_URL.replace("elastic:changeme@", ""), **kw # type: ignore ) # Wait for the cluster to report a status of 'yellow' @@ -83,8 +84,8 @@ def sync_client_factory(): client.close() -@pytest.fixture(scope="function") -def sync_client(sync_client_factory): +@pytest.fixture(scope="function") # type: ignore +def sync_client(sync_client_factory: Any) -> Any: try: yield sync_client_factory finally: diff --git a/test_opensearchpy/test_server/test_helpers/conftest.py b/test_opensearchpy/test_server/test_helpers/conftest.py index 8be79616..b9fc06b8 100644 --- a/test_opensearchpy/test_server/test_helpers/conftest.py +++ b/test_opensearchpy/test_server/test_helpers/conftest.py @@ -27,13 +27,13 @@ import re from datetime import datetime +from typing import Any from pytest import fixture -from opensearchpy.client import OpenSearch from opensearchpy.connection.connections import add_connection from opensearchpy.helpers import bulk -from opensearchpy.helpers.test import get_test_client +from opensearchpy.helpers.test import get_test_client # type: ignore from .test_data import ( DATA, @@ -45,32 +45,32 @@ from .test_document import Comment, History, PullRequest, User -@fixture(scope="session") -def client() -> OpenSearch: +@fixture(scope="session") # type: ignore +def client() -> Any: client = get_test_client(verify_certs=False, http_auth=("admin", "admin")) add_connection("default", client) return client -@fixture(scope="session") -def opensearch_version(client): +@fixture(scope="session") # type: ignore +def opensearch_version(client: Any) -> Any: info = client.info() print(info) yield tuple( int(x) - for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") + for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore ) -@fixture -def write_client(client): +@fixture # type: ignore +def write_client(client: Any) -> Any: yield client client.indices.delete("test-*", ignore=404) client.indices.delete_template("test-template", ignore=404) -@fixture(scope="session") -def data_client(client): +@fixture(scope="session") # type: ignore +def data_client(client: Any) -> Any: # create mappings create_git_index(client, "git") create_flat_git_index(client, "flat-git") @@ -82,8 +82,8 @@ def data_client(client): client.indices.delete("flat-git", ignore=404) -@fixture -def pull_request(write_client): +@fixture # type: ignore +def pull_request(write_client: Any) -> Any: PullRequest.init() pr = PullRequest( _id=42, @@ -106,8 +106,8 @@ def pull_request(write_client): return pr -@fixture -def setup_ubq_tests(client) -> str: +@fixture # type: ignore +def setup_ubq_tests(client: Any) -> str: index = "test-git" create_git_index(client, index) bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True) diff --git a/test_opensearchpy/test_server/test_helpers/test_actions.py b/test_opensearchpy/test_server/test_helpers/test_actions.py index 7fb8f234..ddc0af19 100644 --- a/test_opensearchpy/test_server/test_helpers/test_actions.py +++ b/test_opensearchpy/test_server/test_helpers/test_actions.py @@ -26,7 +26,7 @@ # under the License. -from typing import Tuple +from typing import Any from mock import patch @@ -40,9 +40,9 @@ class FailingBulkClient(object): def __init__( self, - client, - fail_at: Tuple[int] = (2,), - fail_with=TransportError(599, "Error!", {}), + client: Any, + fail_at: Any = (2,), + fail_with: Any = TransportError(599, "Error!", {}), ) -> None: self.client = client self._called = 0 @@ -50,7 +50,7 @@ def __init__( self.transport = client.transport self._fail_with = fail_with - def bulk(self, *args, **kwargs): + def bulk(self, *args: Any, **kwargs: Any) -> Any: self._called += 1 if self._called in self._fail_at: raise self._fail_with @@ -98,8 +98,8 @@ def test_all_errors_from_chunk_are_raised_on_failure(self) -> None: else: assert False, "exception should have been raised" - def test_different_op_types(self): - if self.opensearch_version() < (0, 90, 1): + def test_different_op_types(self) -> Any: + if self.opensearch_version() < (0, 90, 1): # type: ignore raise SkipTest("update supported since 0.90.1") self.client.index(index="i", id=45, body={}) self.client.index(index="i", id=42, body={}) @@ -218,7 +218,7 @@ def test_transport_error_is_raised_with_max_retries(self) -> None: fail_with=TransportError(429, "Rejected!", {}), ) - def streaming_bulk(): + def streaming_bulk() -> Any: results = list( helpers.streaming_bulk( failing_client, @@ -271,7 +271,7 @@ def test_stats_only_reports_numbers(self) -> None: self.assertEqual(0, failed) self.assertEqual(100, self.client.count(index="test-index")["count"]) - def test_errors_are_reported_correctly(self): + def test_errors_are_reported_correctly(self) -> None: self.client.indices.create( "i", { @@ -316,7 +316,7 @@ def test_error_is_raised(self) -> None: index="i", ) - def test_ignore_error_if_raised(self): + def test_ignore_error_if_raised(self) -> None: # ignore the status code 400 in tuple helpers.bulk( self.client, [{"a": 42}, {"a": "c"}], index="i", ignore_status=(400,) @@ -349,7 +349,7 @@ def test_ignore_error_if_raised(self): failing_client = FailingBulkClient(self.client) helpers.bulk(failing_client, [{"a": 42}], index="i", ignore_status=(599,)) - def test_errors_are_collected_properly(self): + def test_errors_are_collected_properly(self) -> None: self.client.indices.create( "i", { @@ -384,12 +384,12 @@ class TestScan(OpenSearchTestCase): }, ] - def teardown_method(self, m) -> None: + def teardown_method(self, m: Any) -> None: self.client.transport.perform_request("DELETE", "/_search/scroll/_all") - super(TestScan, self).teardown_method(m) + super(TestScan, self).teardown_method(m) # type: ignore - def test_order_can_be_preserved(self): - bulk = [] + def test_order_can_be_preserved(self) -> None: + bulk: Any = [] for x in range(100): bulk.append({"index": {"_index": "test_index", "_id": x}}) bulk.append({"answer": x, "correct": x == 42}) @@ -408,8 +408,8 @@ def test_order_can_be_preserved(self): self.assertEqual(list(map(str, range(100))), list(d["_id"] for d in docs)) self.assertEqual(list(range(100)), list(d["_source"]["answer"] for d in docs)) - def test_all_documents_are_read(self): - bulk = [] + def test_all_documents_are_read(self) -> None: + bulk: Any = [] for x in range(100): bulk.append({"index": {"_index": "test_index", "_id": x}}) bulk.append({"answer": x, "correct": x == 42}) @@ -421,8 +421,8 @@ def test_all_documents_are_read(self): self.assertEqual(set(map(str, range(100))), set(d["_id"] for d in docs)) self.assertEqual(set(range(100)), set(d["_source"]["answer"] for d in docs)) - def test_scroll_error(self): - bulk = [] + def test_scroll_error(self) -> None: + bulk: Any = [] for x in range(4): bulk.append({"index": {"_index": "test_index"}}) bulk.append({"value": x}) @@ -456,7 +456,7 @@ def test_scroll_error(self): self.assertEqual(len(data), 3) self.assertEqual(data[-1], {"scroll_data": 42}) - def test_initial_search_error(self): + def test_initial_search_error(self) -> None: with patch.object(self, "client") as client_mock: client_mock.search.return_value = { "_scroll_id": "dummy_id", @@ -491,7 +491,7 @@ def test_no_scroll_id_fast_route(self) -> None: client_mock.scroll.assert_not_called() client_mock.clear_scroll.assert_not_called() - def test_scan_auth_kwargs_forwarded(self): + def test_scan_auth_kwargs_forwarded(self) -> None: for key, val in { "api_key": ("name", "value"), "http_auth": ("username", "password"), @@ -510,7 +510,11 @@ def test_scan_auth_kwargs_forwarded(self): } client_mock.clear_scroll.return_value = {} - data = list(helpers.scan(self.client, index="test_index", **{key: val})) + data = list( + helpers.scan( + self.client, index="test_index", scroll_kwargs={key: val} + ) + ) self.assertEqual(data, [{"search_data": 1}]) @@ -523,7 +527,7 @@ def test_scan_auth_kwargs_forwarded(self): ): self.assertEqual(api_mock.call_args[1][key], val) - def test_scan_auth_kwargs_favor_scroll_kwargs_option(self): + def test_scan_auth_kwargs_favor_scroll_kwargs_option(self) -> None: with patch.object(self, "client") as client_mock: client_mock.search.return_value = { "_scroll_id": "scroll_id", @@ -555,8 +559,8 @@ def test_scan_auth_kwargs_favor_scroll_kwargs_option(self): self.assertEqual(client_mock.scroll.call_args[1]["sort"], "asc") @patch("opensearchpy.helpers.actions.logger") - def test_logger(self, logger_mock): - bulk = [] + def test_logger(self, logger_mock: Any) -> None: + bulk: Any = [] for x in range(4): bulk.append({"index": {"_index": "test_index"}}) bulk.append({"value": x}) @@ -590,8 +594,8 @@ def test_logger(self, logger_mock): pass logger_mock.warning.assert_called() - def test_clear_scroll(self): - bulk = [] + def test_clear_scroll(self) -> None: + bulk: Any = [] for x in range(4): bulk.append({"index": {"_index": "test_index"}}) bulk.append({"value": x}) @@ -617,7 +621,7 @@ def test_clear_scroll(self): ) spy.assert_not_called() - def test_shards_no_skipped_field(self): + def test_shards_no_skipped_field(self) -> None: with patch.object(self, "client") as client_mock: client_mock.search.return_value = { "_scroll_id": "dummy_id", @@ -646,8 +650,8 @@ def test_shards_no_skipped_field(self): class TestReindex(OpenSearchTestCase): - def setup_method(self, _): - bulk = [] + def setup_method(self, _: Any) -> None: + bulk: Any = [] for x in range(100): bulk.append({"index": {"_index": "test_index", "_id": x}}) bulk.append( @@ -716,7 +720,7 @@ def test_all_documents_get_moved(self) -> None: class TestParentChildReindex(OpenSearchTestCase): - def setup_method(self, _): + def setup_method(self, _: Any) -> None: body = { "settings": {"number_of_shards": 1, "number_of_replicas": 0}, "mappings": { diff --git a/test_opensearchpy/test_server/test_helpers/test_analysis.py b/test_opensearchpy/test_server/test_helpers/test_analysis.py index 2da9388a..e965e05b 100644 --- a/test_opensearchpy/test_server/test_helpers/test_analysis.py +++ b/test_opensearchpy/test_server/test_helpers/test_analysis.py @@ -25,10 +25,12 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + from opensearchpy import analyzer, token_filter, tokenizer -def test_simulate_with_just__builtin_tokenizer(client) -> None: +def test_simulate_with_just__builtin_tokenizer(client: Any) -> None: a = analyzer("my-analyzer", tokenizer="keyword") tokens = a.simulate("Hello World!", using=client).tokens @@ -36,7 +38,7 @@ def test_simulate_with_just__builtin_tokenizer(client) -> None: assert tokens[0].token == "Hello World!" -def test_simulate_complex(client) -> None: +def test_simulate_complex(client: Any) -> None: a = analyzer( "my-analyzer", tokenizer=tokenizer("split_words", "simple_pattern_split", pattern=":"), @@ -49,7 +51,7 @@ def test_simulate_complex(client) -> None: assert ["this", "works"] == [t.token for t in tokens] -def test_simulate_builtin(client) -> None: +def test_simulate_builtin(client: Any) -> None: a = analyzer("my-analyzer", "english") tokens = a.simulate("fixes running").tokens diff --git a/test_opensearchpy/test_server/test_helpers/test_count.py b/test_opensearchpy/test_server/test_helpers/test_count.py index 7bf9c27e..65f424d1 100644 --- a/test_opensearchpy/test_server/test_helpers/test_count.py +++ b/test_opensearchpy/test_server/test_helpers/test_count.py @@ -25,15 +25,17 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + from opensearchpy.helpers.search import Q, Search -def test_count_all(data_client) -> None: +def test_count_all(data_client: Any) -> None: s = Search(using=data_client).index("git") assert 53 == s.count() -def test_count_prefetch(data_client, mocker) -> None: +def test_count_prefetch(data_client: Any, mocker: Any) -> None: mocker.spy(data_client, "count") search = Search(using=data_client).index("git") @@ -46,7 +48,7 @@ def test_count_prefetch(data_client, mocker) -> None: assert data_client.count.call_count == 1 -def test_count_filter(data_client) -> None: +def test_count_filter(data_client: Any) -> None: s = Search(using=data_client).index("git").filter(~Q("exists", field="parent_shas")) # initial commit + repo document assert 2 == s.count() diff --git a/test_opensearchpy/test_server/test_helpers/test_data.py b/test_opensearchpy/test_server/test_helpers/test_data.py index 63302b7a..11ad915f 100644 --- a/test_opensearchpy/test_server/test_helpers/test_data.py +++ b/test_opensearchpy/test_server/test_helpers/test_data.py @@ -30,7 +30,7 @@ from typing import Any, Dict -def create_flat_git_index(client, index): +def create_flat_git_index(client: Any, index: Any) -> None: # we will use user on several places user_mapping = { "properties": {"name": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} @@ -73,7 +73,7 @@ def create_flat_git_index(client, index): ) -def create_git_index(client, index): +def create_git_index(client: Any, index: Any) -> None: # we will use user on several places user_mapping = { "properties": {"name": {"type": "text", "fields": {"raw": {"type": "keyword"}}}} @@ -1095,7 +1095,7 @@ def create_git_index(client, index): ] -def flatten_doc(d) -> Dict[str, Any]: +def flatten_doc(d: Any) -> Dict[str, Any]: src = d["_source"].copy() del src["commit_repo"] return {"_index": "flat-git", "_id": d["_id"], "_source": src} @@ -1104,7 +1104,7 @@ def flatten_doc(d) -> Dict[str, Any]: FLAT_DATA = [flatten_doc(d) for d in DATA if "routing" in d] -def create_test_git_data(d) -> Dict[str, Any]: +def create_test_git_data(d: Any) -> Dict[str, Any]: src = d["_source"].copy() return { "_index": "test-git", diff --git a/test_opensearchpy/test_server/test_helpers/test_document.py b/test_opensearchpy/test_server/test_helpers/test_document.py index 0da4b856..ad0bf289 100644 --- a/test_opensearchpy/test_server/test_helpers/test_document.py +++ b/test_opensearchpy/test_server/test_helpers/test_document.py @@ -27,6 +27,7 @@ from datetime import datetime from ipaddress import ip_address +from typing import Any import pytest from pytest import raises @@ -78,7 +79,7 @@ class Repository(Document): tags = Keyword() @classmethod - def search(cls): + def search(cls, using: Any = None, index: Any = None) -> Any: return super(Repository, cls).search().filter("term", commit_repo="repo") class Index: @@ -131,7 +132,7 @@ class Index: name = "test-serialization" -def test_serialization(write_client): +def test_serialization(write_client: Any) -> None: SerializationDoc.init() write_client.index( index="test-serialization", @@ -161,7 +162,7 @@ def test_serialization(write_client): } -def test_nested_inner_hits_are_wrapped_properly(pull_request) -> None: +def test_nested_inner_hits_are_wrapped_properly(pull_request: Any) -> None: history_query = Q( "nested", path="comments.history", @@ -189,7 +190,7 @@ def test_nested_inner_hits_are_wrapped_properly(pull_request) -> None: assert "score" in history.meta -def test_nested_inner_hits_are_deserialized_properly(pull_request) -> None: +def test_nested_inner_hits_are_deserialized_properly(pull_request: Any) -> None: s = PullRequest.search().query( "nested", inner_hits={}, @@ -204,7 +205,7 @@ def test_nested_inner_hits_are_deserialized_properly(pull_request) -> None: assert isinstance(pr.comments[0].created_at, datetime) -def test_nested_top_hits_are_wrapped_properly(pull_request) -> None: +def test_nested_top_hits_are_wrapped_properly(pull_request: Any) -> None: s = PullRequest.search() s.aggs.bucket("comments", "nested", path="comments").metric( "hits", "top_hits", size=1 @@ -216,7 +217,7 @@ def test_nested_top_hits_are_wrapped_properly(pull_request) -> None: assert isinstance(r.aggregations.comments.hits.hits[0], Comment) -def test_update_object_field(write_client) -> None: +def test_update_object_field(write_client: Any) -> None: Wiki.init() w = Wiki( owner=User(name="Honza Kral"), @@ -236,7 +237,7 @@ def test_update_object_field(write_client) -> None: assert w.ranked == {"test1": 0.1, "topic2": 0.2} -def test_update_script(write_client) -> None: +def test_update_script(write_client: Any) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) w.save() @@ -246,7 +247,7 @@ def test_update_script(write_client) -> None: assert w.views == 47 -def test_update_retry_on_conflict(write_client) -> None: +def test_update_retry_on_conflict(write_client: Any) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) w.save() @@ -260,8 +261,8 @@ def test_update_retry_on_conflict(write_client) -> None: assert w.views == 52 -@pytest.mark.parametrize("retry_on_conflict", [None, 0]) -def test_update_conflicting_version(write_client, retry_on_conflict) -> None: +@pytest.mark.parametrize("retry_on_conflict", [None, 0]) # type: ignore +def test_update_conflicting_version(write_client: Any, retry_on_conflict: Any) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) w.save() @@ -278,7 +279,7 @@ def test_update_conflicting_version(write_client, retry_on_conflict) -> None: ) -def test_save_and_update_return_doc_meta(write_client) -> None: +def test_save_and_update_return_doc_meta(write_client: Any) -> None: Wiki.init() w = Wiki(owner=User(name="Honza Kral"), _id="opensearch-py", views=42) resp = w.save(return_doc_meta=True) @@ -302,31 +303,33 @@ def test_save_and_update_return_doc_meta(write_client) -> None: assert resp.keys().__contains__("_version") -def test_init(write_client) -> None: +def test_init(write_client: Any) -> None: Repository.init(index="test-git") assert write_client.indices.exists(index="test-git") -def test_get_raises_404_on_index_missing(data_client) -> None: +def test_get_raises_404_on_index_missing(data_client: Any) -> None: with raises(NotFoundError): Repository.get("opensearch-dsl-php", index="not-there") -def test_get_raises_404_on_non_existent_id(data_client) -> None: +def test_get_raises_404_on_non_existent_id(data_client: Any) -> None: with raises(NotFoundError): Repository.get("opensearch-dsl-php") -def test_get_returns_none_if_404_ignored(data_client) -> None: +def test_get_returns_none_if_404_ignored(data_client: Any) -> None: assert None is Repository.get("opensearch-dsl-php", ignore=404) -def test_get_returns_none_if_404_ignored_and_index_doesnt_exist(data_client) -> None: +def test_get_returns_none_if_404_ignored_and_index_doesnt_exist( + data_client: Any, +) -> None: assert None is Repository.get("42", index="not-there", ignore=404) -def test_get(data_client) -> None: +def test_get(data_client: Any) -> None: opensearch_repo = Repository.get("opensearch-py") assert isinstance(opensearch_repo, Repository) @@ -334,15 +337,15 @@ def test_get(data_client) -> None: assert datetime(2014, 3, 3) == opensearch_repo.created_at -def test_exists_return_true(data_client) -> None: +def test_exists_return_true(data_client: Any) -> None: assert Repository.exists("opensearch-py") -def test_exists_false(data_client) -> None: +def test_exists_false(data_client: Any) -> None: assert not Repository.exists("opensearch-dsl-php") -def test_get_with_tz_date(data_client) -> None: +def test_get_with_tz_date(data_client: Any) -> None: first_commit = Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="opensearch-py" ) @@ -354,7 +357,7 @@ def test_get_with_tz_date(data_client) -> None: ) -def test_save_with_tz_date(data_client) -> None: +def test_save_with_tz_date(data_client: Any) -> None: tzinfo = timezone("Europe/Prague") first_commit = Commit.get( id="3ca6e1e73a071a705b4babd2f581c91a2a3e5037", routing="opensearch-py" @@ -381,7 +384,7 @@ def test_save_with_tz_date(data_client) -> None: ] -def test_mget(data_client) -> None: +def test_mget(data_client: Any) -> None: commits = Commit.mget(COMMIT_DOCS_WITH_MISSING) assert commits[0] is None assert commits[1].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" @@ -389,23 +392,23 @@ def test_mget(data_client) -> None: assert commits[3].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" -def test_mget_raises_exception_when_missing_param_is_invalid(data_client) -> None: +def test_mget_raises_exception_when_missing_param_is_invalid(data_client: Any) -> None: with raises(ValueError): Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raj") -def test_mget_raises_404_when_missing_param_is_raise(data_client) -> None: +def test_mget_raises_404_when_missing_param_is_raise(data_client: Any) -> None: with raises(NotFoundError): Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="raise") -def test_mget_ignores_missing_docs_when_missing_param_is_skip(data_client) -> None: +def test_mget_ignores_missing_docs_when_missing_param_is_skip(data_client: Any) -> None: commits = Commit.mget(COMMIT_DOCS_WITH_MISSING, missing="skip") assert commits[0].meta.id == "3ca6e1e73a071a705b4babd2f581c91a2a3e5037" assert commits[1].meta.id == "eb3e543323f189fd7b698e66295427204fff5755" -def test_update_works_from_search_response(data_client) -> None: +def test_update_works_from_search_response(data_client: Any) -> None: opensearch_repo = Repository.search().execute()[0] opensearch_repo.update(owner={"other_name": "opensearchpy"}) @@ -416,7 +419,7 @@ def test_update_works_from_search_response(data_client) -> None: assert "opensearch" == new_version.owner.name -def test_update(data_client) -> None: +def test_update(data_client: Any) -> None: opensearch_repo = Repository.get("opensearch-py") v = opensearch_repo.meta.version @@ -440,7 +443,7 @@ def test_update(data_client) -> None: assert "primary_term" in new_version.meta -def test_save_updates_existing_doc(data_client) -> None: +def test_save_updates_existing_doc(data_client: Any) -> None: opensearch_repo = Repository.get("opensearch-py") opensearch_repo.new_field = "testing-save" @@ -453,7 +456,7 @@ def test_save_updates_existing_doc(data_client) -> None: assert new_repo["_seq_no"] == opensearch_repo.meta.seq_no -def test_save_automatically_uses_seq_no_and_primary_term(data_client) -> None: +def test_save_automatically_uses_seq_no_and_primary_term(data_client: Any) -> None: opensearch_repo = Repository.get("opensearch-py") opensearch_repo.meta.seq_no += 1 @@ -461,7 +464,7 @@ def test_save_automatically_uses_seq_no_and_primary_term(data_client) -> None: opensearch_repo.save() -def test_delete_automatically_uses_seq_no_and_primary_term(data_client) -> None: +def test_delete_automatically_uses_seq_no_and_primary_term(data_client: Any) -> None: opensearch_repo = Repository.get("opensearch-py") opensearch_repo.meta.seq_no += 1 @@ -469,13 +472,13 @@ def test_delete_automatically_uses_seq_no_and_primary_term(data_client) -> None: opensearch_repo.delete() -def assert_doc_equals(expected, actual) -> None: +def assert_doc_equals(expected: Any, actual: Any) -> None: for f in expected: assert f in actual assert actual[f] == expected[f] -def test_can_save_to_different_index(write_client): +def test_can_save_to_different_index(write_client: Any) -> None: test_repo = Repository(description="testing", meta={"id": 42}) assert test_repo.save(index="test-document") @@ -490,7 +493,7 @@ def test_can_save_to_different_index(write_client): ) -def test_save_without_skip_empty_will_include_empty_fields(write_client) -> None: +def test_save_without_skip_empty_will_include_empty_fields(write_client: Any) -> None: test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) assert test_repo.save(index="test-document", skip_empty=False) @@ -505,7 +508,7 @@ def test_save_without_skip_empty_will_include_empty_fields(write_client) -> None ) -def test_delete(write_client) -> None: +def test_delete(write_client: Any) -> None: write_client.create( index="test-document", id="opensearch-py", @@ -526,11 +529,11 @@ def test_delete(write_client) -> None: ) -def test_search(data_client) -> None: +def test_search(data_client: Any) -> None: assert Repository.search().count() == 1 -def test_search_returns_proper_doc_classes(data_client) -> None: +def test_search_returns_proper_doc_classes(data_client: Any) -> None: result = Repository.search().execute() opensearch_repo = result.hits[0] @@ -539,11 +542,13 @@ def test_search_returns_proper_doc_classes(data_client) -> None: assert opensearch_repo.owner.name == "opensearch" -def test_refresh_mapping(data_client) -> None: +def test_refresh_mapping(data_client: Any) -> None: class Commit(Document): class Index: name = "git" + _index: Any + Commit._index.load_mappings() assert "stats" in Commit._index._mapping @@ -553,7 +558,7 @@ class Index: assert isinstance(Commit._index._mapping["committed_date"], Date) -def test_highlight_in_meta(data_client) -> None: +def test_highlight_in_meta(data_client: Any) -> None: commit = ( Commit.search() .query("match", description="inverting") diff --git a/test_opensearchpy/test_server/test_helpers/test_faceted_search.py b/test_opensearchpy/test_server/test_helpers/test_faceted_search.py index 4656d4b2..38dd40cd 100644 --- a/test_opensearchpy/test_server/test_helpers/test_faceted_search.py +++ b/test_opensearchpy/test_server/test_helpers/test_faceted_search.py @@ -26,6 +26,7 @@ # under the License. from datetime import datetime +from typing import Any import pytest @@ -66,8 +67,8 @@ class MetricSearch(FacetedSearch): } -@pytest.fixture(scope="session") -def commit_search_cls(opensearch_version): +@pytest.fixture(scope="session") # type: ignore +def commit_search_cls(opensearch_version: Any) -> Any: interval_kwargs = {"fixed_interval": "1d"} class CommitSearch(FacetedSearch): @@ -91,8 +92,8 @@ class CommitSearch(FacetedSearch): return CommitSearch -@pytest.fixture(scope="session") -def repo_search_cls(opensearch_version): +@pytest.fixture(scope="session") # type: ignore +def repo_search_cls(opensearch_version: Any) -> Any: interval_type = "calendar_interval" class RepoSearch(FacetedSearch): @@ -105,15 +106,15 @@ class RepoSearch(FacetedSearch): ), } - def search(self): + def search(self) -> Any: s = super(RepoSearch, self).search() return s.filter("term", commit_repo="repo") return RepoSearch -@pytest.fixture(scope="session") -def pr_search_cls(opensearch_version): +@pytest.fixture(scope="session") # type: ignore +def pr_search_cls(opensearch_version: Any) -> Any: interval_type = "calendar_interval" class PRSearch(FacetedSearch): @@ -131,7 +132,7 @@ class PRSearch(FacetedSearch): return PRSearch -def test_facet_with_custom_metric(data_client) -> None: +def test_facet_with_custom_metric(data_client: Any) -> None: ms = MetricSearch() r = ms.execute() @@ -140,7 +141,7 @@ def test_facet_with_custom_metric(data_client) -> None: assert dates[0] == 1399038439000 -def test_nested_facet(pull_request, pr_search_cls) -> None: +def test_nested_facet(pull_request: Any, pr_search_cls: Any) -> None: prs = pr_search_cls() r = prs.execute() @@ -148,7 +149,7 @@ def test_nested_facet(pull_request, pr_search_cls) -> None: assert [(datetime(2018, 1, 1, 0, 0), 1, False)] == r.facets.comments -def test_nested_facet_with_filter(pull_request, pr_search_cls) -> None: +def test_nested_facet_with_filter(pull_request: Any, pr_search_cls: Any) -> None: prs = pr_search_cls(filters={"comments": datetime(2018, 1, 1, 0, 0)}) r = prs.execute() @@ -160,7 +161,7 @@ def test_nested_facet_with_filter(pull_request, pr_search_cls) -> None: assert not r.hits -def test_datehistogram_facet(data_client, repo_search_cls) -> None: +def test_datehistogram_facet(data_client: Any, repo_search_cls: Any) -> None: rs = repo_search_cls() r = rs.execute() @@ -168,7 +169,7 @@ def test_datehistogram_facet(data_client, repo_search_cls) -> None: assert [(datetime(2014, 3, 1, 0, 0), 1, False)] == r.facets.created -def test_boolean_facet(data_client, repo_search_cls) -> None: +def test_boolean_facet(data_client: Any, repo_search_cls: Any) -> None: rs = repo_search_cls() r = rs.execute() @@ -179,7 +180,7 @@ def test_boolean_facet(data_client, repo_search_cls) -> None: def test_empty_search_finds_everything( - data_client, opensearch_version, commit_search_cls + data_client: Any, opensearch_version: Any, commit_search_cls: Any ) -> None: cs = commit_search_cls() r = cs.execute() @@ -225,7 +226,7 @@ def test_empty_search_finds_everything( def test_term_filters_are_shown_as_selected_and_data_is_filtered( - data_client, commit_search_cls + data_client: Any, commit_search_cls: Any ) -> None: cs = commit_search_cls(filters={"files": "test_opensearchpy/test_dsl"}) @@ -271,7 +272,7 @@ def test_term_filters_are_shown_as_selected_and_data_is_filtered( def test_range_filters_are_shown_as_selected_and_data_is_filtered( - data_client, commit_search_cls + data_client: Any, commit_search_cls: Any ) -> None: cs = commit_search_cls(filters={"deletions": "better"}) @@ -280,7 +281,7 @@ def test_range_filters_are_shown_as_selected_and_data_is_filtered( assert 19 == r.hits.total.value -def test_pagination(data_client, commit_search_cls) -> None: +def test_pagination(data_client: Any, commit_search_cls: Any) -> None: cs = commit_search_cls() cs = cs[0:20] diff --git a/test_opensearchpy/test_server/test_helpers/test_index.py b/test_opensearchpy/test_server/test_helpers/test_index.py index 8593459c..71f0501a 100644 --- a/test_opensearchpy/test_server/test_helpers/test_index.py +++ b/test_opensearchpy/test_server/test_helpers/test_index.py @@ -25,6 +25,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + from opensearchpy import Date, Document, Index, IndexTemplate, Text from opensearchpy.helpers import analysis @@ -34,7 +36,7 @@ class Post(Document): published_from = Date() -def test_index_template_works(write_client) -> None: +def test_index_template_works(write_client: Any) -> None: it = IndexTemplate("test-template", "test-*") it.document(Post) it.settings(number_of_replicas=0, number_of_shards=1) @@ -55,7 +57,7 @@ def test_index_template_works(write_client) -> None: } == write_client.indices.get_mapping(index="test-blog") -def test_index_can_be_saved_even_with_settings(write_client) -> None: +def test_index_can_be_saved_even_with_settings(write_client: Any) -> None: i = Index("test-blog", using=write_client) i.settings(number_of_shards=3, number_of_replicas=0) i.save() @@ -67,12 +69,12 @@ def test_index_can_be_saved_even_with_settings(write_client) -> None: ) -def test_index_exists(data_client) -> None: +def test_index_exists(data_client: Any) -> None: assert Index("git").exists() assert not Index("not-there").exists() -def test_index_can_be_created_with_settings_and_mappings(write_client) -> None: +def test_index_can_be_created_with_settings_and_mappings(write_client: Any) -> None: i = Index("test-blog", using=write_client) i.document(Post) i.settings(number_of_replicas=0, number_of_shards=1) @@ -97,7 +99,7 @@ def test_index_can_be_created_with_settings_and_mappings(write_client) -> None: } -def test_delete(write_client) -> None: +def test_delete(write_client: Any) -> None: write_client.indices.create( index="test-index", body={"settings": {"number_of_replicas": 0, "number_of_shards": 1}}, @@ -108,7 +110,7 @@ def test_delete(write_client) -> None: assert not write_client.indices.exists(index="test-index") -def test_multiple_indices_with_same_doc_type_work(write_client) -> None: +def test_multiple_indices_with_same_doc_type_work(write_client: Any) -> None: i1 = Index("test-index-1", using=write_client) i2 = Index("test-index-2", using=write_client) @@ -116,8 +118,8 @@ def test_multiple_indices_with_same_doc_type_work(write_client) -> None: i.document(Post) i.create() - for i in ("test-index-1", "test-index-2"): - settings = write_client.indices.get_settings(index=i) - assert settings[i]["settings"]["index"]["analysis"] == { + for j in ("test-index-1", "test-index-2"): + settings = write_client.indices.get_settings(index=j) + assert settings[j]["settings"]["index"]["analysis"] == { "analyzer": {"my_analyzer": {"type": "custom", "tokenizer": "keyword"}} } diff --git a/test_opensearchpy/test_server/test_helpers/test_mapping.py b/test_opensearchpy/test_server/test_helpers/test_mapping.py index 50a80dea..722a249e 100644 --- a/test_opensearchpy/test_server/test_helpers/test_mapping.py +++ b/test_opensearchpy/test_server/test_helpers/test_mapping.py @@ -25,13 +25,15 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + from pytest import raises from opensearchpy import exceptions from opensearchpy.helpers import analysis, mapping -def test_mapping_saved_into_opensearch(write_client) -> None: +def test_mapping_saved_into_opensearch(write_client: Any) -> None: m = mapping.Mapping() m.field( "name", "text", analyzer=analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -52,7 +54,7 @@ def test_mapping_saved_into_opensearch(write_client) -> None: def test_mapping_saved_into_opensearch_when_index_already_exists_closed( - write_client, + write_client: Any, ) -> None: m = mapping.Mapping() m.field( @@ -77,7 +79,7 @@ def test_mapping_saved_into_opensearch_when_index_already_exists_closed( def test_mapping_saved_into_opensearch_when_index_already_exists_with_analysis( - write_client, + write_client: Any, ) -> None: m = mapping.Mapping() analyzer = analysis.analyzer("my_analyzer", tokenizer="keyword") @@ -107,7 +109,7 @@ def test_mapping_saved_into_opensearch_when_index_already_exists_with_analysis( } == write_client.indices.get_mapping(index="test-mapping") -def test_mapping_gets_updated_from_opensearch(write_client): +def test_mapping_gets_updated_from_opensearch(write_client: Any) -> None: write_client.indices.create( index="test-mapping", body={ diff --git a/test_opensearchpy/test_server/test_helpers/test_search.py b/test_opensearchpy/test_server/test_helpers/test_search.py index 5e45645a..4fb00597 100644 --- a/test_opensearchpy/test_server/test_helpers/test_search.py +++ b/test_opensearchpy/test_server/test_helpers/test_search.py @@ -27,6 +27,8 @@ from __future__ import unicode_literals +from typing import Any + from pytest import raises from opensearchpy import ( @@ -50,7 +52,7 @@ class Repository(Document): tags = Keyword() @classmethod - def search(cls): + def search(cls, using: Any = None, index: Any = None) -> Any: return super(Repository, cls).search().filter("term", commit_repo="repo") class Index: @@ -62,7 +64,7 @@ class Index: name = "flat-git" -def test_filters_aggregation_buckets_are_accessible(data_client) -> None: +def test_filters_aggregation_buckets_are_accessible(data_client: Any) -> None: has_tests_query = Q("term", files="test_opensearchpy/test_dsl") s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").bucket( @@ -83,7 +85,7 @@ def test_filters_aggregation_buckets_are_accessible(data_client) -> None: ) -def test_top_hits_are_wrapped_in_response(data_client) -> None: +def test_top_hits_are_wrapped_in_response(data_client: Any) -> None: s = Commit.search()[0:0] s.aggs.bucket("top_authors", "terms", field="author.name.raw").metric( "top_commits", "top_hits", size=5 @@ -99,7 +101,7 @@ def test_top_hits_are_wrapped_in_response(data_client) -> None: assert isinstance(hits[0], Commit) -def test_inner_hits_are_wrapped_in_response(data_client) -> None: +def test_inner_hits_are_wrapped_in_response(data_client: Any) -> None: s = Search(index="git")[0:1].query( "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all") ) @@ -110,7 +112,7 @@ def test_inner_hits_are_wrapped_in_response(data_client) -> None: assert repr(commit.meta.inner_hits.repo[0]).startswith(" None: +def test_scan_respects_doc_types(data_client: Any) -> None: repos = list(Repository.search().scan()) assert 1 == len(repos) @@ -118,7 +120,7 @@ def test_scan_respects_doc_types(data_client) -> None: assert repos[0].organization == "opensearch" -def test_scan_iterates_through_all_docs(data_client) -> None: +def test_scan_iterates_through_all_docs(data_client: Any) -> None: s = Search(index="flat-git") commits = list(s.scan()) @@ -127,7 +129,7 @@ def test_scan_iterates_through_all_docs(data_client) -> None: assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} -def test_response_is_cached(data_client) -> None: +def test_response_is_cached(data_client: Any) -> None: s = Repository.search() repos = list(s) @@ -135,7 +137,7 @@ def test_response_is_cached(data_client) -> None: assert s._response.hits == repos -def test_multi_search(data_client) -> None: +def test_multi_search(data_client: Any) -> None: s1 = Repository.search() s2 = Search(index="flat-git") @@ -152,7 +154,7 @@ def test_multi_search(data_client) -> None: assert r2._search is s2 -def test_multi_missing(data_client) -> None: +def test_multi_missing(data_client: Any) -> None: s1 = Repository.search() s2 = Search(index="flat-git") s3 = Search(index="does_not_exist") @@ -175,7 +177,7 @@ def test_multi_missing(data_client) -> None: assert r3 is None -def test_raw_subfield_can_be_used_in_aggs(data_client) -> None: +def test_raw_subfield_can_be_used_in_aggs(data_client: Any) -> None: s = Search(index="git")[0:0] s.aggs.bucket("authors", "terms", field="author.name.raw", size=1) diff --git a/test_opensearchpy/test_server/test_helpers/test_update_by_query.py b/test_opensearchpy/test_server/test_helpers/test_update_by_query.py index fb46e956..dfc4d250 100644 --- a/test_opensearchpy/test_server/test_helpers/test_update_by_query.py +++ b/test_opensearchpy/test_server/test_helpers/test_update_by_query.py @@ -25,11 +25,13 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + from opensearchpy.helpers.search import Q from opensearchpy.helpers.update_by_query import UpdateByQuery -def test_update_by_query_no_script(write_client, setup_ubq_tests) -> None: +def test_update_by_query_no_script(write_client: Any, setup_ubq_tests: Any) -> None: index = setup_ubq_tests ubq = ( @@ -48,7 +50,7 @@ def test_update_by_query_no_script(write_client, setup_ubq_tests) -> None: assert response.success() -def test_update_by_query_with_script(write_client, setup_ubq_tests) -> None: +def test_update_by_query_with_script(write_client: Any, setup_ubq_tests: Any) -> None: index = setup_ubq_tests ubq = ( @@ -65,7 +67,7 @@ def test_update_by_query_with_script(write_client, setup_ubq_tests) -> None: assert response.version_conflicts == 0 -def test_delete_by_query_with_script(write_client, setup_ubq_tests) -> None: +def test_delete_by_query_with_script(write_client: Any, setup_ubq_tests: Any) -> None: index = setup_ubq_tests ubq = ( diff --git a/test_opensearchpy/test_server/test_plugins/test_alerting.py b/test_opensearchpy/test_server/test_plugins/test_alerting.py index d127edb1..fe3ee80a 100644 --- a/test_opensearchpy/test_server/test_plugins/test_alerting.py +++ b/test_opensearchpy/test_server/test_plugins/test_alerting.py @@ -13,7 +13,7 @@ import unittest -from opensearchpy.helpers.test import OPENSEARCH_VERSION +from opensearchpy.helpers.test import OPENSEARCH_VERSION # type: ignore from .. import OpenSearchTestCase @@ -23,7 +23,7 @@ class TestAlertingPlugin(OpenSearchTestCase): (OPENSEARCH_VERSION) and (OPENSEARCH_VERSION < (2, 0, 0)), "Plugin not supported for opensearch version", ) - def test_create_destination(self): + def test_create_destination(self) -> None: # Test to create alert destination dummy_destination = { "name": "my-destination", @@ -54,7 +54,7 @@ def test_get_destination(self) -> None: (OPENSEARCH_VERSION) and (OPENSEARCH_VERSION < (2, 0, 0)), "Plugin not supported for opensearch version", ) - def test_create_monitor(self): + def test_create_monitor(self) -> None: # Create a dummy destination self.test_create_destination() diff --git a/test_opensearchpy/test_server/test_rest_api_spec.py b/test_opensearchpy/test_server/test_rest_api_spec.py index ba16d044..dc3d252c 100644 --- a/test_opensearchpy/test_server/test_rest_api_spec.py +++ b/test_opensearchpy/test_server/test_rest_api_spec.py @@ -36,6 +36,7 @@ import re import warnings import zipfile +from typing import Any import pytest import urllib3 @@ -44,7 +45,7 @@ from opensearchpy import OpenSearchWarning, TransportError from opensearchpy.client.utils import _base64_auth_header from opensearchpy.compat import string_types -from opensearchpy.helpers.test import _get_version +from opensearchpy.helpers.test import _get_version # type: ignore from . import get_client @@ -142,23 +143,23 @@ class YamlRunner: - def __init__(self, client) -> None: + def __init__(self, client: Any) -> None: self.client = client - self.last_response = None + self.last_response: Any = None - self._run_code = None - self._setup_code = None - self._teardown_code = None - self._state = {} + self._run_code: Any = None + self._setup_code: Any = None + self._teardown_code: Any = None + self._state: Any = {} - def use_spec(self, test_spec) -> None: + def use_spec(self, test_spec: Any) -> None: self._setup_code = test_spec.pop("setup", None) self._run_code = test_spec.pop("run", None) self._teardown_code = test_spec.pop("teardown", None) - def setup(self): + def setup(self) -> Any: # Pull skips from individual tests to not do unnecessary setup. - skip_code = [] + skip_code: Any = [] for action in self._run_code: assert len(action) == 1 action_type, _ = list(action.items())[0] @@ -174,12 +175,12 @@ def setup(self): if self._setup_code: self.run_code(self._setup_code) - def teardown(self) -> None: + def teardown(self) -> Any: if self._teardown_code: self.section("teardown") self.run_code(self._teardown_code) - def opensearch_version(self): + def opensearch_version(self) -> Any: global OPENSEARCH_VERSION if OPENSEARCH_VERSION is None: version_string = (self.client.info())["version"]["number"] @@ -189,10 +190,10 @@ def opensearch_version(self): OPENSEARCH_VERSION = tuple(int(v) if v.isdigit() else 99 for v in version) return OPENSEARCH_VERSION - def section(self, name) -> None: + def section(self, name: str) -> None: print(("=" * 10) + " " + name + " " + ("=" * 10)) - def run(self) -> None: + def run(self) -> Any: try: self.setup() self.section("test") @@ -203,7 +204,7 @@ def run(self) -> None: except Exception: pass - def run_code(self, test) -> None: + def run_code(self, test: Any) -> Any: """Execute an instruction based on its type.""" for action in test: assert len(action) == 1 @@ -215,7 +216,7 @@ def run_code(self, test) -> None: else: raise RuntimeError("Invalid action type %r" % (action_type,)) - def run_do(self, action) -> None: + def run_do(self, action: Any) -> Any: api = self.client headers = action.pop("headers", None) catch = action.pop("catch", None) @@ -267,7 +268,7 @@ def run_do(self, action) -> None: # Filter out warnings raised by other components. caught_warnings = [ - str(w.message) + str(w.message) # type: ignore for w in caught_warnings if w.category == OpenSearchWarning and str(w.message) not in allowed_warnings @@ -275,13 +276,13 @@ def run_do(self, action) -> None: # Sorting removes the issue with order raised. We only care about # if all warnings are raised in the single API call. - if warn and sorted(warn) != sorted(caught_warnings): + if warn and sorted(warn) != sorted(caught_warnings): # type: ignore raise AssertionError( "Expected warnings not equal to actual warnings: expected=%r actual=%r" % (warn, caught_warnings) ) - def run_catch(self, catch, exception) -> None: + def run_catch(self, catch: Any, exception: Any) -> None: if catch == "param": assert isinstance(exception, TypeError) return @@ -296,7 +297,7 @@ def run_catch(self, catch, exception) -> None: ) is not None self.last_response = exception.info - def run_skip(self, skip) -> None: + def run_skip(self, skip: Any) -> Any: global IMPLEMENTED_FEATURES if "features" in skip: @@ -318,32 +319,32 @@ def run_skip(self, skip) -> None: if min_version <= (self.opensearch_version()) <= max_version: pytest.skip(reason) - def run_gt(self, action) -> None: + def run_gt(self, action: Any) -> None: for key, value in action.items(): value = self._resolve(value) assert self._lookup(key) > value - def run_gte(self, action) -> None: + def run_gte(self, action: Any) -> None: for key, value in action.items(): value = self._resolve(value) assert self._lookup(key) >= value - def run_lt(self, action) -> None: + def run_lt(self, action: Any) -> None: for key, value in action.items(): value = self._resolve(value) assert self._lookup(key) < value - def run_lte(self, action) -> None: + def run_lte(self, action: Any) -> None: for key, value in action.items(): value = self._resolve(value) assert self._lookup(key) <= value - def run_set(self, action) -> None: + def run_set(self, action: Any) -> None: for key, value in action.items(): value = self._resolve(value) self._state[value] = self._lookup(key) - def run_is_false(self, action) -> None: + def run_is_false(self, action: Any) -> None: try: value = self._lookup(action) except AssertionError: @@ -351,23 +352,23 @@ def run_is_false(self, action) -> None: else: assert value in FALSEY_VALUES - def run_is_true(self, action) -> None: + def run_is_true(self, action: Any) -> None: value = self._lookup(action) assert value not in FALSEY_VALUES - def run_length(self, action) -> None: + def run_length(self, action: Any) -> None: for path, expected in action.items(): value = self._lookup(path) expected = self._resolve(expected) assert expected == len(value) - def run_match(self, action) -> None: + def run_match(self, action: Any) -> None: for path, expected in action.items(): value = self._lookup(path) expected = self._resolve(expected) if ( - isinstance(expected, string_types) + isinstance(expected, str) and expected.startswith("/") and expected.endswith("/") ): @@ -379,7 +380,7 @@ def run_match(self, action) -> None: else: self._assert_match_equals(value, expected) - def run_contains(self, action) -> None: + def run_contains(self, action: Any) -> None: for path, expected in action.items(): value = self._lookup(path) # list[dict[str,str]] is returned expected = self._resolve(expected) # dict[str, str] @@ -387,7 +388,7 @@ def run_contains(self, action) -> None: if expected not in value: raise AssertionError("%s is not contained by %s" % (expected, value)) - def run_transform_and_set(self, action) -> None: + def run_transform_and_set(self, action: Any) -> None: for key, value in action.items(): # Convert #base64EncodeCredentials(id,api_key) to ["id", "api_key"] if "#base64EncodeCredentials" in value: @@ -397,7 +398,7 @@ def run_transform_and_set(self, action) -> None: (self._lookup(value[0]), self._lookup(value[1])) ) - def _resolve(self, value): + def _resolve(self, value: Any) -> Any: # resolve variables if isinstance(value, string_types) and "$" in value: for k, v in self._state.items(): @@ -422,12 +423,13 @@ def _resolve(self, value): value = list(map(self._resolve, value)) return value - def _lookup(self, path): + def _lookup(self, path: str) -> Any: # fetch the possibly nested value from last_response - value = self.last_response + value: Any = self.last_response if path == "$body": return value path = path.replace(r"\.", "\1") + step: Any for step in path.split("."): if not step: continue @@ -449,10 +451,10 @@ def _lookup(self, path): value = value[step] return value - def _feature_enabled(self, name) -> bool: + def _feature_enabled(self, name: str) -> Any: return False - def _assert_match_equals(self, a, b) -> None: + def _assert_match_equals(self, a: Any, b: Any) -> None: # Handle for large floating points with 'E' if isinstance(b, string_types) and isinstance(a, float) and "e" in repr(a): a = repr(a).replace("e+", "E") @@ -460,8 +462,8 @@ def _assert_match_equals(self, a, b) -> None: assert a == b, "%r does not match %r" % (a, b) -@pytest.fixture(scope="function") -def sync_runner(sync_client): +@pytest.fixture(scope="function") # type: ignore +def sync_runner(sync_client: Any) -> Any: return YamlRunner(sync_client) @@ -532,8 +534,8 @@ def sync_runner(sync_client): if not RUN_ASYNC_REST_API_TESTS: - @pytest.mark.parametrize("test_spec", YAML_TEST_SPECS) - def test_rest_api_spec(test_spec, sync_runner) -> None: + @pytest.mark.parametrize("test_spec", YAML_TEST_SPECS) # type: ignore + def test_rest_api_spec(test_spec: Any, sync_runner: Any) -> None: if test_spec.get("skip", False): pytest.skip("Manually skipped in 'SKIP_TESTS'") sync_runner.use_spec(test_spec) diff --git a/test_opensearchpy/test_server_secured/test_clients.py b/test_opensearchpy/test_server_secured/test_clients.py index 94684ffb..f5aef284 100644 --- a/test_opensearchpy/test_server_secured/test_clients.py +++ b/test_opensearchpy/test_server_secured/test_clients.py @@ -11,7 +11,7 @@ from unittest import TestCase from opensearchpy import OpenSearch -from opensearchpy.helpers.test import OPENSEARCH_URL +from opensearchpy.helpers.test import OPENSEARCH_URL # type: ignore class TestSecurity(TestCase): diff --git a/test_opensearchpy/test_server_secured/test_security_plugin.py b/test_opensearchpy/test_server_secured/test_security_plugin.py index 5c719953..5c309580 100644 --- a/test_opensearchpy/test_server_secured/test_security_plugin.py +++ b/test_opensearchpy/test_server_secured/test_security_plugin.py @@ -15,7 +15,7 @@ from opensearchpy.connection.connections import add_connection from opensearchpy.exceptions import NotFoundError -from opensearchpy.helpers.test import get_test_client +from opensearchpy.helpers.test import get_test_client # type: ignore class TestSecurityPlugin(TestCase): @@ -114,7 +114,7 @@ def test_create_user_with_body_param_empty(self) -> None: else: assert False - def test_create_user_with_role(self): + def test_create_user_with_role(self) -> None: self.test_create_role() # Test to create user diff --git a/test_opensearchpy/test_transport.py b/test_opensearchpy/test_transport.py index a69a7cf0..dc1a8f9e 100644 --- a/test_opensearchpy/test_transport.py +++ b/test_opensearchpy/test_transport.py @@ -30,6 +30,7 @@ import json import time +from typing import Any from mock import patch @@ -42,14 +43,14 @@ class DummyConnection(Connection): - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: self.exception = kwargs.pop("exception", None) self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}") self.headers = kwargs.pop("headers", {}) - self.calls = [] + self.calls: Any = [] super(DummyConnection, self).__init__(**kwargs) - def perform_request(self, *args, **kwargs): + def perform_request(self, *args: Any, **kwargs: Any) -> Any: self.calls.append((args, kwargs)) if self.exception: raise self.exception @@ -119,20 +120,20 @@ def test_cluster_manager_only_nodes_are_ignored(self) -> None: chosen = [ i for i, node_info in enumerate(nodes) - if get_host_info(node_info, i) is not None + if get_host_info(node_info, i) is not None # type: ignore ] self.assertEqual([1, 2, 3, 4], chosen) class TestTransport(TestCase): def test_single_connection_uses_dummy_connection_pool(self) -> None: - t = Transport([{}]) - self.assertIsInstance(t.connection_pool, DummyConnectionPool) - t = Transport([{"host": "localhost"}]) - self.assertIsInstance(t.connection_pool, DummyConnectionPool) + t1: Any = Transport([{}]) + self.assertIsInstance(t1.connection_pool, DummyConnectionPool) + t2: Any = Transport([{"host": "localhost"}]) + self.assertIsInstance(t2.connection_pool, DummyConnectionPool) def test_request_timeout_extracted_from_params_and_passed(self) -> None: - t = Transport([{}], connection_class=DummyConnection) + t: Any = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", params={"request_timeout": 42}) self.assertEqual(1, len(t.get_connection().calls)) @@ -143,7 +144,7 @@ def test_request_timeout_extracted_from_params_and_passed(self) -> None: ) def test_timeout_extracted_from_params_and_passed(self) -> None: - t = Transport([{}], connection_class=DummyConnection) + t: Any = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", params={"timeout": 84}) self.assertEqual(1, len(t.get_connection().calls)) @@ -154,7 +155,7 @@ def test_timeout_extracted_from_params_and_passed(self) -> None: ) def test_opaque_id(self) -> None: - t = Transport([{}], opaque_id="app-1", connection_class=DummyConnection) + t: Any = Transport([{}], opaque_id="app-1", connection_class=DummyConnection) t.perform_request("GET", "/") self.assertEqual(1, len(t.get_connection().calls)) @@ -174,7 +175,7 @@ def test_opaque_id(self) -> None: ) def test_request_with_custom_user_agent_header(self) -> None: - t = Transport([{}], connection_class=DummyConnection) + t: Any = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", headers={"user-agent": "my-custom-value/1.2.3"}) self.assertEqual(1, len(t.get_connection().calls)) @@ -188,7 +189,9 @@ def test_request_with_custom_user_agent_header(self) -> None: ) def test_send_get_body_as_source(self) -> None: - t = Transport([{}], send_get_body_as="source", connection_class=DummyConnection) + t: Any = Transport( + [{}], send_get_body_as="source", connection_class=DummyConnection + ) t.perform_request("GET", "/", body={}) self.assertEqual(1, len(t.get_connection().calls)) @@ -197,14 +200,16 @@ def test_send_get_body_as_source(self) -> None: ) def test_send_get_body_as_post(self) -> None: - t = Transport([{}], send_get_body_as="POST", connection_class=DummyConnection) + t: Any = Transport( + [{}], send_get_body_as="POST", connection_class=DummyConnection + ) t.perform_request("GET", "/", body={}) self.assertEqual(1, len(t.get_connection().calls)) self.assertEqual(("POST", "/", None, b"{}"), t.get_connection().calls[0][0]) def test_body_gets_encoded_into_bytes(self) -> None: - t = Transport([{}], connection_class=DummyConnection) + t: Any = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", body="你好") self.assertEqual(1, len(t.get_connection().calls)) @@ -214,7 +219,7 @@ def test_body_gets_encoded_into_bytes(self) -> None: ) def test_body_bytes_get_passed_untouched(self) -> None: - t = Transport([{}], connection_class=DummyConnection) + t: Any = Transport([{}], connection_class=DummyConnection) body = b"\xe4\xbd\xa0\xe5\xa5\xbd" t.perform_request("GET", "/", body=body) @@ -222,7 +227,7 @@ def test_body_bytes_get_passed_untouched(self) -> None: self.assertEqual(("GET", "/", None, body), t.get_connection().calls[0][0]) def test_body_surrogates_replaced_encoded_into_bytes(self) -> None: - t = Transport([{}], connection_class=DummyConnection) + t: Any = Transport([{}], connection_class=DummyConnection) t.perform_request("GET", "/", body="你好\uda6a") self.assertEqual(1, len(t.get_connection().calls)) @@ -232,26 +237,26 @@ def test_body_surrogates_replaced_encoded_into_bytes(self) -> None: ) def test_kwargs_passed_on_to_connections(self) -> None: - t = Transport([{"host": "google.com"}], port=123) + t: Any = Transport([{"host": "google.com"}], port=123) self.assertEqual(1, len(t.connection_pool.connections)) self.assertEqual("http://google.com:123", t.connection_pool.connections[0].host) def test_kwargs_passed_on_to_connection_pool(self) -> None: dt = object() - t = Transport([{}, {}], dead_timeout=dt) + t: Any = Transport([{}, {}], dead_timeout=dt) self.assertIs(dt, t.connection_pool.dead_timeout) def test_custom_connection_class(self) -> None: - class MyConnection(object): - def __init__(self, **kwargs): + class MyConnection(Connection): + def __init__(self, **kwargs: Any) -> None: self.kwargs = kwargs - t = Transport([{}], connection_class=MyConnection) + t: Any = Transport([{}], connection_class=MyConnection) self.assertEqual(1, len(t.connection_pool.connections)) self.assertIsInstance(t.connection_pool.connections[0], MyConnection) def test_add_connection(self) -> None: - t = Transport([{}], randomize_hosts=False) + t: Any = Transport([{}], randomize_hosts=False) t.add_connection({"host": "google.com", "port": 1234}) self.assertEqual(2, len(t.connection_pool.connections)) @@ -260,7 +265,7 @@ def test_add_connection(self) -> None: ) def test_request_will_fail_after_X_retries(self) -> None: - t = Transport( + t: Any = Transport( [{"exception": ConnectionError("abandon ship")}], connection_class=DummyConnection, ) @@ -269,7 +274,7 @@ def test_request_will_fail_after_X_retries(self) -> None: self.assertEqual(4, len(t.get_connection().calls)) def test_failed_connection_will_be_marked_as_dead(self) -> None: - t = Transport( + t: Any = Transport( [{"exception": ConnectionError("abandon ship")}] * 2, connection_class=DummyConnection, ) @@ -279,7 +284,7 @@ def test_failed_connection_will_be_marked_as_dead(self) -> None: def test_resurrected_connection_will_be_marked_as_live_on_success(self) -> None: for method in ("GET", "HEAD"): - t = Transport([{}, {}], connection_class=DummyConnection) + t: Any = Transport([{}, {}], connection_class=DummyConnection) con1 = t.connection_pool.get_connection() con2 = t.connection_pool.get_connection() t.connection_pool.mark_dead(con1) @@ -290,7 +295,7 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self) -> None: self.assertEqual(1, len(t.connection_pool.dead_count)) def test_sniff_will_use_seed_connections(self) -> None: - t = Transport([{"data": CLUSTER_NODES}], connection_class=DummyConnection) + t: Any = Transport([{"data": CLUSTER_NODES}], connection_class=DummyConnection) t.set_connections([{"data": "invalid"}]) t.sniff_hosts() @@ -298,7 +303,7 @@ def test_sniff_will_use_seed_connections(self) -> None: self.assertEqual("http://1.1.1.1:123", t.get_connection().host) def test_sniff_on_start_fetches_and_uses_nodes_list(self) -> None: - t = Transport( + t: Any = Transport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True, @@ -307,7 +312,7 @@ def test_sniff_on_start_fetches_and_uses_nodes_list(self) -> None: self.assertEqual("http://1.1.1.1:123", t.get_connection().host) def test_sniff_on_start_ignores_sniff_timeout(self) -> None: - t = Transport( + t: Any = Transport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True, @@ -319,7 +324,7 @@ def test_sniff_on_start_ignores_sniff_timeout(self) -> None: ) def test_sniff_uses_sniff_timeout(self) -> None: - t = Transport( + t: Any = Transport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_timeout=42, @@ -330,8 +335,8 @@ def test_sniff_uses_sniff_timeout(self) -> None: t.seed_connections[0].calls[0], ) - def test_sniff_reuses_connection_instances_if_possible(self): - t = Transport( + def test_sniff_reuses_connection_instances_if_possible(self) -> None: + t: Any = Transport( [{"data": CLUSTER_NODES}, {"host": "1.1.1.1", "port": 123}], connection_class=DummyConnection, randomize_hosts=False, @@ -342,8 +347,8 @@ def test_sniff_reuses_connection_instances_if_possible(self): self.assertEqual(1, len(t.connection_pool.connections)) self.assertIs(connection, t.get_connection()) - def test_sniff_on_fail_triggers_sniffing_on_fail(self): - t = Transport( + def test_sniff_on_fail_triggers_sniffing_on_fail(self) -> None: + t: Any = Transport( [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_connection_fail=True, @@ -356,9 +361,11 @@ def test_sniff_on_fail_triggers_sniffing_on_fail(self): self.assertEqual("http://1.1.1.1:123", t.get_connection().host) @patch("opensearchpy.transport.Transport.sniff_hosts") - def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): + def test_sniff_on_fail_failing_does_not_prevent_retires( + self, sniff_hosts: Any + ) -> None: sniff_hosts.side_effect = [TransportError("sniff failed")] - t = Transport( + t: Any = Transport( [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_connection_fail=True, @@ -374,7 +381,7 @@ def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): self.assertEqual(1, len(conn_data.calls)) def test_sniff_after_n_seconds(self) -> None: - t = Transport( + t: Any = Transport( [{"data": CLUSTER_NODES}], connection_class=DummyConnection, sniffer_timeout=5, @@ -394,7 +401,7 @@ def test_sniff_after_n_seconds(self) -> None: def test_sniff_7x_publish_host(self) -> None: # Test the response shaped when a 7.x node has publish_host set # and the returend data is shaped in the fqdn/ip:port format. - t = Transport( + t: Any = Transport( [{"data": CLUSTER_NODES_7x_PUBLISH_HOST}], connection_class=DummyConnection, sniff_timeout=42, diff --git a/test_opensearchpy/utils.py b/test_opensearchpy/utils.py index 5aa4983b..50682d35 100644 --- a/test_opensearchpy/utils.py +++ b/test_opensearchpy/utils.py @@ -27,11 +27,12 @@ import time +from typing import Any from opensearchpy import OpenSearch -def wipe_cluster(client) -> None: +def wipe_cluster(client: Any) -> None: """Wipes a cluster clean between test cases""" close_after_wipe = False try: @@ -59,9 +60,9 @@ def wipe_cluster(client) -> None: client.close() -def wipe_cluster_settings(client) -> None: +def wipe_cluster_settings(client: Any) -> None: settings = client.cluster.get_settings() - new_settings = {} + new_settings: Any = {} for name, value in settings.items(): if value: new_settings.setdefault(name, {}) @@ -71,7 +72,7 @@ def wipe_cluster_settings(client) -> None: client.cluster.put_settings(body=new_settings) -def wipe_snapshots(client): +def wipe_snapshots(client: Any) -> None: """Deletes all the snapshots and repositories from the cluster""" in_progress_snapshots = [] @@ -96,14 +97,14 @@ def wipe_snapshots(client): assert in_progress_snapshots == [] -def wipe_data_streams(client) -> None: +def wipe_data_streams(client: Any) -> None: try: client.indices.delete_data_stream(name="*", expand_wildcards="all") except Exception: client.indices.delete_data_stream(name="*") -def wipe_indices(client) -> None: +def wipe_indices(client: Any) -> None: client.indices.delete( index="*,-.ds-ilm-history-*", expand_wildcards="all", @@ -111,7 +112,7 @@ def wipe_indices(client) -> None: ) -def wipe_searchable_snapshot_indices(client) -> None: +def wipe_searchable_snapshot_indices(client: Any) -> None: cluster_metadata = client.cluster.state( metric="metadata", filter_path="metadata.indices.*.settings.index.store.snapshot", @@ -121,17 +122,17 @@ def wipe_searchable_snapshot_indices(client) -> None: client.indices.delete(index=index) -def wipe_slm_policies(client) -> None: +def wipe_slm_policies(client: Any) -> None: for policy in client.slm.get_lifecycle(): client.slm.delete_lifecycle(policy_id=policy["name"]) -def wipe_auto_follow_patterns(client) -> None: +def wipe_auto_follow_patterns(client: Any) -> None: for pattern in client.ccr.get_auto_follow_pattern()["patterns"]: client.ccr.delete_auto_follow_pattern(name=pattern["name"]) -def wipe_node_shutdown_metadata(client) -> None: +def wipe_node_shutdown_metadata(client: Any) -> None: shutdown_status = client.shutdown.get_node() # If response contains these two keys the feature flag isn't enabled # on this cluster so skip this step now. @@ -143,14 +144,14 @@ def wipe_node_shutdown_metadata(client) -> None: client.shutdown.delete_node(node_id=node_id) -def wipe_tasks(client) -> None: +def wipe_tasks(client: Any) -> None: tasks = client.tasks.list() for node_name, node in tasks.get("node", {}).items(): for task_id in node.get("tasks", ()): client.tasks.cancel(task_id=task_id, wait_for_completion=True) -def wait_for_pending_tasks(client, filter, timeout: int = 30) -> None: +def wait_for_pending_tasks(client: Any, filter: Any, timeout: int = 30) -> None: end_time = time.time() + timeout while time.time() < end_time: tasks = client.cat.tasks(detailed=True).split("\n") @@ -158,7 +159,7 @@ def wait_for_pending_tasks(client, filter, timeout: int = 30) -> None: break -def wait_for_pending_datafeeds_and_jobs(client, timeout: int = 30) -> None: +def wait_for_pending_datafeeds_and_jobs(client: Any, timeout: int = 30) -> None: end_time = time.time() + timeout while time.time() < end_time: if ( @@ -171,7 +172,7 @@ def wait_for_pending_datafeeds_and_jobs(client, timeout: int = 30) -> None: break -def wait_for_cluster_state_updates_to_finish(client, timeout: int = 30) -> None: +def wait_for_cluster_state_updates_to_finish(client: Any, timeout: int = 30) -> None: end_time = time.time() + timeout while time.time() < end_time: if not client.cluster.pending_tasks().get("tasks", ()): diff --git a/utils/build-dists.py b/utils/build-dists.py index b45da98e..d5517015 100644 --- a/utils/build-dists.py +++ b/utils/build-dists.py @@ -38,13 +38,14 @@ import shutil import sys import tempfile +from typing import Any base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) tmp_dir = None -@contextlib.contextmanager -def set_tmp_dir(): +@contextlib.contextmanager # type: ignore +def set_tmp_dir() -> None: # type: ignore global tmp_dir tmp_dir = tempfile.mkdtemp() yield tmp_dir @@ -52,7 +53,7 @@ def set_tmp_dir(): tmp_dir = None -def run(*argv, expect_exit_code: int = 0) -> None: +def run(*argv: Any, expect_exit_code: int = 0) -> None: global tmp_dir if tmp_dir is None: os.chdir(base_dir) @@ -70,9 +71,9 @@ def run(*argv, expect_exit_code: int = 0) -> None: exit(exit_code or 1) -def test_dist(dist) -> None: - with set_tmp_dir() as tmp_dir: - dist_name = re.match( +def test_dist(dist: Any) -> None: + with set_tmp_dir() as tmp_dir: # type: ignore + dist_name = re.match( # type: ignore r"^(opensearchpy\d*)-", os.path.basename(dist) .replace("opensearch-py", "opensearchpy") @@ -216,7 +217,7 @@ def main() -> None: # alpha/beta/rc -> aN/bN/rcN else: pre_number = re.search(r"-(a|b|rc)(?:lpha|eta|)(\d+)$", expect_version) - version = version + pre_number.group(1) + pre_number.group(2) + version = version + pre_number.group(1) + pre_number.group(2) # type: ignore expect_version = re.sub( r"(?:-(?:SNAPSHOT|alpha\d+|beta\d+|rc\d+))+$", "", expect_version diff --git a/utils/generate-api.py b/utils/generate-api.py index f53e212c..792446dd 100644 --- a/utils/generate-api.py +++ b/utils/generate-api.py @@ -37,6 +37,7 @@ from itertools import chain, groupby from operator import itemgetter from pathlib import Path +from typing import Any, Dict import black import deepmerge @@ -78,27 +79,27 @@ ) -def blacken(filename) -> None: +def blacken(filename: Any) -> None: runner = CliRunner() result = runner.invoke(black.main, [str(filename)]) assert result.exit_code == 0, result.output @lru_cache() -def is_valid_url(url): +def is_valid_url(url: str) -> bool: return 200 <= http.request("HEAD", url).status < 400 class Module: - def __init__(self, namespace) -> None: - self.namespace = namespace - self._apis = [] + def __init__(self, namespace: str) -> None: + self.namespace: Any = namespace + self._apis: Any = [] self.parse_orig() - def add(self, api) -> None: + def add(self, api: Any) -> None: self._apis.append(api) - def parse_orig(self): + def parse_orig(self) -> None: self.orders = [] self.header = "from typing import Any, Collection, Optional, Tuple, Union\n\n" @@ -129,7 +130,7 @@ def parse_orig(self): r"\n (?:async )?def ([a-z_]+)\(", content, re.MULTILINE ) - def _position(self, api): + def _position(self, api: Any) -> Any: try: return self.orders.index(api.name) except ValueError: @@ -234,12 +235,12 @@ def dump(self) -> None: f.write(file_content) @property - def filepath(self): + def filepath(self) -> Any: return CODE_ROOT / f"opensearchpy/_async/client/{self.namespace}.py" class API: - def __init__(self, namespace, name, definition) -> None: + def __init__(self, namespace: str, name: str, definition: Any) -> None: self.namespace = namespace self.name = name @@ -284,7 +285,7 @@ def __init__(self, namespace, name, definition) -> None: print(f"URL {revised_url!r}, falling back on {self.doc_url!r}") @property - def all_parts(self): + def all_parts(self) -> Dict[str, str]: parts = {} for url in self._def["url"]["paths"]: parts.update(url.get("parts", {})) @@ -309,7 +310,7 @@ def all_parts(self): dynamic, components = self.url_parts - def ind(item): + def ind(item: Any) -> Any: try: return components.index(item[0]) except ValueError: @@ -319,29 +320,29 @@ def ind(item): return parts @property - def params(self): + def params(self) -> Any: parts = self.all_parts params = self._def.get("params", {}) return chain( - ((p, parts[p]) for p in parts if parts[p]["required"]), + ((p, parts[p]) for p in parts if parts[p]["required"]), # type: ignore (("body", self.body),) if self.body else (), ( (p, parts[p]) for p in parts - if not parts[p]["required"] and p not in params + if not parts[p]["required"] and p not in params # type: ignore ), sorted(params.items(), key=lambda x: (x[0] not in parts, x[0])), ) @property - def body(self): + def body(self) -> Any: b = self._def.get("body", {}) if b: b.setdefault("required", False) return b @property - def query_params(self): + def query_params(self) -> Any: return ( k for k in sorted(self._def.get("params", {}).keys()) @@ -349,7 +350,7 @@ def query_params(self): ) @property - def all_func_params(self): + def all_func_params(self) -> Any: """Parameters that will be in the '@query_params' decorator list and parameters that will be in the function signature. This doesn't include @@ -362,14 +363,14 @@ def all_func_params(self): return params @property - def path(self): + def path(self) -> Any: return max( (path for path in self._def["url"]["paths"]), key=lambda p: len(re.findall(r"\{([^}]+)\}", p["path"])), ) @property - def method(self): + def method(self) -> Any: # To adhere to the HTTP RFC we shouldn't send # bodies in GET requests. default_method = self.path["methods"][0] @@ -382,7 +383,7 @@ def method(self): return default_method @property - def url_parts(self): + def url_parts(self) -> Any: path = self.path["path"] dynamic = "{" in path @@ -403,14 +404,14 @@ def url_parts(self): return dynamic, parts @property - def required_parts(self): + def required_parts(self) -> Any: parts = self.all_parts - required = [p for p in parts if parts[p]["required"]] + required = [p for p in parts if parts[p]["required"]] # type: ignore if self.body.get("required"): required.append("body") return required - def to_python(self): + def to_python(self) -> Any: try: t = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}") except TemplateNotFound: @@ -423,7 +424,7 @@ def to_python(self): ) -def read_modules(): +def read_modules() -> Any: modules = {} # Load the OpenAPI specification file @@ -596,8 +597,8 @@ def read_modules(): if "POST" in methods or "PUT" in methods: api.update( { - "stability": "stable", - "visibility": "public", + "stability": "stable", # type: ignore + "visibility": "public", # type: ignore "headers": { "accept": ["application/json"], "content_type": ["application/json"], @@ -607,8 +608,8 @@ def read_modules(): else: api.update( { - "stability": "stable", - "visibility": "public", + "stability": "stable", # type: ignore + "visibility": "public", # type: ignore "headers": {"accept": ["application/json"]}, } ) @@ -641,7 +642,7 @@ def read_modules(): return modules -def apply_patch(namespace, name, api): +def apply_patch(namespace: str, name: str, api: Any) -> Any: override_file_path = ( CODE_ROOT / "utils/templates/overrides" / namespace / f"{name}.json" ) @@ -652,7 +653,7 @@ def apply_patch(namespace, name, api): return api -def dump_modules(modules): +def dump_modules(modules: Any) -> None: for mod in modules.values(): mod.dump()