diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b587a75..835c403d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Removed unnecessary `# -*- coding: utf-8 -*-` headers from .py files ([#615](https://github.com/opensearch-project/opensearch-py/pull/615), [#617](https://github.com/opensearch-project/opensearch-py/pull/617)) ### Fixed - Fix KeyError when scroll return no hits ([#616](https://github.com/opensearch-project/opensearch-py/pull/616)) +- Fix reuse of `OpenSearch` using `Urllib3HttpConnection` and `AsyncOpenSearch` after calling `close` ([#639](https://github.com/opensearch-project/opensearch-py/pull/639)) ### Security ### Dependencies - Bumps `pytest-asyncio` from <=0.21.1 to <=0.23.2 diff --git a/opensearchpy/_async/http_aiohttp.py b/opensearchpy/_async/http_aiohttp.py index b1baf148..c49fd574 100644 --- a/opensearchpy/_async/http_aiohttp.py +++ b/opensearchpy/_async/http_aiohttp.py @@ -361,6 +361,7 @@ async def close(self) -> Any: """ if self.session: await self.session.close() + self.session = None async def _create_aiohttp_session(self) -> Any: """Creates an aiohttp.ClientSession(). This is delayed until diff --git a/opensearchpy/connection/http_async.py b/opensearchpy/connection/http_async.py index 468f3244..f5a4ec7c 100644 --- a/opensearchpy/connection/http_async.py +++ b/opensearchpy/connection/http_async.py @@ -277,6 +277,7 @@ async def close(self) -> Any: """ if self.session: await self.session.close() + self.session = None async def _create_aiohttp_session(self) -> Any: """Creates an aiohttp.ClientSession(). This is delayed until diff --git a/opensearchpy/connection/http_urllib3.py b/opensearchpy/connection/http_urllib3.py index 54f2a22a..ab9a1a78 100644 --- a/opensearchpy/connection/http_urllib3.py +++ b/opensearchpy/connection/http_urllib3.py @@ -214,9 +214,13 @@ def __init__( if pool_maxsize and isinstance(pool_maxsize, int): kw["maxsize"] = pool_maxsize - self.pool = pool_class( + self._urllib3_pool_factory = lambda: pool_class( self.hostname, port=self.port, timeout=self.timeout, **kw ) + self._create_urllib3_pool() + + def _create_urllib3_pool(self) -> None: + self.pool = self._urllib3_pool_factory() # type: ignore def perform_request( self, @@ -228,6 +232,10 @@ def perform_request( ignore: Collection[int] = (), headers: Optional[Mapping[str, str]] = None, ) -> Any: + if self.pool is None: + self._create_urllib3_pool() + assert self.pool is not None + url = self.url_prefix + url if params: url = "%s?%s" % (url, urlencode(params)) @@ -305,4 +313,6 @@ def close(self) -> None: """ Explicitly closes connection """ - self.pool.close() + if self.pool: + self.pool.close() + self.pool = None diff --git a/test_opensearchpy/test_async/test_server/test_clients.py b/test_opensearchpy/test_async/test_server/test_clients.py index cee6bc7b..521f0600 100644 --- a/test_opensearchpy/test_async/test_server/test_clients.py +++ b/test_opensearchpy/test_async/test_server/test_clients.py @@ -67,3 +67,15 @@ async def test_aiohttp_connection_works_without_yarl( resp = await async_client.info(pretty=True) assert isinstance(resp, dict) + + +class TestClose: + async def test_close_doesnt_break_client(self, async_client: Any) -> None: + await async_client.cluster.health() + await async_client.close() + await async_client.cluster.health() + + async def test_with_doesnt_break_client(self, async_client: Any) -> None: + for _ in range(2): + async with async_client as client: + await client.cluster.health() diff --git a/test_opensearchpy/test_server/test_clients.py b/test_opensearchpy/test_server/test_clients.py index e945b69a..a77b0f37 100644 --- a/test_opensearchpy/test_server/test_clients.py +++ b/test_opensearchpy/test_server/test_clients.py @@ -49,3 +49,15 @@ def test_bulk_works_with_bytestring_body(self) -> None: self.assertFalse(response["errors"]) self.assertEqual(1, len(response["items"])) + + +class TestClose(OpenSearchTestCase): + def test_close_doesnt_break_client(self) -> None: + self.client.cluster.health() + self.client.close() + self.client.cluster.health() + + def test_with_doesnt_break_client(self) -> None: + for _ in range(2): + with self.client as client: + client.cluster.health()