diff --git a/CHANGELOG.md b/CHANGELOG.md index 2af46a82..3354d4e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Added sync and async sample that uses `search_after` parameter ([859](https://github.com/opensearch-project/opensearch-py/pull/859)) ### Updated APIs ### Changed +- Small refactor of AWS Signer classes for both sync and async clients ([866](https://github.com/opensearch-project/opensearch-py/pull/866)) ### Deprecated ### Removed ### Fixed diff --git a/opensearchpy/helpers/asyncsigner.py b/opensearchpy/helpers/asyncsigner.py index 930d0081..47259033 100644 --- a/opensearchpy/helpers/asyncsigner.py +++ b/opensearchpy/helpers/asyncsigner.py @@ -8,7 +8,8 @@ # GitHub history for details. from typing import Any, Dict, Optional, Union -from urllib.parse import parse_qs, urlencode, urlparse + +from opensearchpy.helpers.signer import AWSV4Signer class AWSV4SignerAsyncAuth: @@ -17,33 +18,21 @@ class AWSV4SignerAsyncAuth: """ def __init__(self, credentials: Any, region: str, service: str = "es") -> None: - if not credentials: - raise ValueError("Credentials cannot be empty") - self.credentials = credentials - - if not region: - raise ValueError("Region cannot be empty") - self.region = region - - if not service: - raise ValueError("Service name cannot be empty") - self.service = service + self.signer = AWSV4Signer(credentials, region, service) def __call__( self, method: str, url: str, - query_string: Optional[str] = None, body: Optional[Union[str, bytes]] = None, headers: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - return self._sign_request(method, url, query_string, body, headers) + return self._sign_request(method=method, url=url, body=body, headers=headers) def _sign_request( self, method: str, url: str, - query_string: Optional[str], body: Optional[Union[str, bytes]], headers: Optional[Dict[str, str]], ) -> Dict[str, str]: @@ -53,58 +42,10 @@ def _sign_request( :return: signed headers """ - from botocore.auth import SigV4Auth - from botocore.awsrequest import AWSRequest - - signature_host = self._fetch_url(url, headers or dict()) - - # create an AWS request object and sign it using SigV4Auth - aws_request = AWSRequest( + updated_headers = self.signer.sign( method=method, - url=signature_host, - data=body, - ) - - # credentials objects expose access_key, secret_key and token attributes - # via @property annotations that call _refresh() on every access, - # creating a race condition if the credentials expire before secret_key - # is called but after access_key- the end result is the access_key doesn't - # correspond to the secret_key used to sign the request. To avoid this, - # get_frozen_credentials() which returns non-refreshing credentials is - # called if it exists. - credentials = ( - self.credentials.get_frozen_credentials() - if hasattr(self.credentials, "get_frozen_credentials") - and callable(self.credentials.get_frozen_credentials) - else self.credentials + url=url, + body=body, + headers=headers, ) - - sig_v4_auth = SigV4Auth(credentials, self.service, self.region) - sig_v4_auth.add_auth(aws_request) - aws_request.headers["X-Amz-Content-SHA256"] = sig_v4_auth.payload(aws_request) - - # copy the headers from AWS request object into the prepared_request - return dict(aws_request.headers.items()) - - def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str: - """ - This is a util method that helps in reconstructing the request url. - :param prepared_request: unsigned request - :return: reconstructed url - """ - parsed_url = urlparse(url) - path = parsed_url.path or "/" - - # fetch the query string if present in the request - querystring = "" - if parsed_url.query: - querystring = "?" + urlencode( - parse_qs(parsed_url.query, keep_blank_values=True), doseq=True - ) - - # fetch the host information from headers - headers = {key.lower(): value for key, value in (headers or dict()).items()} - location = headers.get("host") or parsed_url.netloc - - # construct the url and return - return parsed_url.scheme + "://" + location + path + querystring + return updated_headers diff --git a/opensearchpy/helpers/signer.py b/opensearchpy/helpers/signer.py index 0258a859..b346e16a 100644 --- a/opensearchpy/helpers/signer.py +++ b/opensearchpy/helpers/signer.py @@ -7,7 +7,7 @@ # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional from urllib.parse import parse_qs, urlencode, urlparse import requests @@ -31,7 +31,9 @@ def __init__(self, credentials, region: str, service: str = "es") -> Any: # typ raise ValueError("Service name cannot be empty") self.service = service - def sign(self, method: str, url: str, body: Any) -> Dict[str, str]: + def sign( + self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, str]: """ This method signs the request and returns headers. :param method: HTTP method @@ -43,8 +45,10 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]: from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest + signature_host = self._fetch_url(url, headers or dict()) + # create an AWS request object and sign it using SigV4Auth - aws_request = AWSRequest(method=method.upper(), url=url, data=body) + aws_request = AWSRequest(method=method.upper(), url=signature_host, data=body) # credentials objects expose access_key, secret_key and token attributes # via @property annotations that call _refresh() on every access, @@ -69,6 +73,30 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]: return headers + @staticmethod + def _fetch_url(url: str, headers: Optional[Dict[str, str]]) -> str: + """ + This is a util method that helps in reconstructing the request url. + :param prepared_request: unsigned request + :return: reconstructed url + """ + parsed_url = urlparse(url) + path = parsed_url.path or "/" + + # fetch the query string if present in the request + querystring = "" + if parsed_url.query: + querystring = "?" + urlencode( + parse_qs(parsed_url.query, keep_blank_values=True), doseq=True + ) + + # fetch the host information from headers + headers = {key.lower(): value for key, value in (headers or dict()).items()} + location = headers.get("host") or parsed_url.netloc + + # construct the url and return + return parsed_url.scheme + "://" + location + path + querystring + class RequestsAWSV4SignerAuth(requests.auth.AuthBase): """ @@ -89,40 +117,16 @@ def _sign_request(self, prepared_request): # type: ignore :return: signed request """ - prepared_request.headers.update( - self.signer.sign( - prepared_request.method, - self._fetch_url(prepared_request), - prepared_request.body, - ) + updated_headers = self.signer.sign( + method=prepared_request.method, + url=prepared_request.url, + body=prepared_request.body, + headers=prepared_request.headers, ) - return prepared_request - - def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str: - """ - This is a util method that helps in reconstructing the request url. - :param prepared_request: unsigned request - :return: reconstructed url - """ - url = urlparse(prepared_request.url) - path = url.path or "/" - - # fetch the query string if present in the request - querystring = "" - if url.query: - querystring = "?" + urlencode( - parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore - ) + prepared_request.headers.update(updated_headers) - # fetch the host information from headers - headers = { - key.lower(): value for key, value in prepared_request.headers.items() - } - location = headers.get("host") or url.netloc - - # construct the url and return - return url.scheme + "://" + location + path + querystring # type: ignore + return prepared_request # Deprecated: use RequestsAWSV4SignerAuth @@ -135,5 +139,7 @@ def __init__(self, credentials, region, service: str = "es") -> None: # type: i self.signer = AWSV4Signer(credentials, region, service) self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600 - def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]: - return self.signer.sign(method, url, body) + def __call__( + self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, str]: + return self.signer.sign(method, url, body, headers) diff --git a/test_opensearchpy/test_async/test_signer.py b/test_opensearchpy/test_async/test_signer.py index 98109d7a..02dd43d5 100644 --- a/test_opensearchpy/test_async/test_signer.py +++ b/test_opensearchpy/test_async/test_signer.py @@ -9,7 +9,7 @@ import uuid from typing import Any, Collection, Dict, Mapping, Optional, Tuple, Union -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from _pytest.mark.structures import MarkDecorator @@ -81,15 +81,18 @@ async def test_aws_signer_async_fetch_url_with_querystring(self) -> None: region = "us-west-2" service = "aoss" - from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth - - auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service) + from botocore.awsrequest import AWSRequest - signature_host = auth._fetch_url( - "http://localhost/?foo=bar", headers={"host": "otherhost"} - ) + from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth - assert signature_host == "http://otherhost/?foo=bar" + with patch( + "botocore.awsrequest.AWSRequest", side_effect=AWSRequest + ) as mock_aws_request: + auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service) + auth("GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}) + mock_aws_request.assert_called_with( + method="GET", url="http://otherhost:443/?foo=bar", data=None + ) class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner): @@ -155,7 +158,6 @@ def _sign_request( self, method: str, url: str, - query_string: Optional[str] = None, body: Optional[Union[str, bytes]] = None, headers: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: diff --git a/test_opensearchpy/test_connection/test_requests_http_connection.py b/test_opensearchpy/test_connection/test_requests_http_connection.py index a1aee810..235c4929 100644 --- a/test_opensearchpy/test_connection/test_requests_http_connection.py +++ b/test_opensearchpy/test_connection/test_requests_http_connection.py @@ -457,22 +457,27 @@ def mock_session(self) -> Any: return dummy_session - def test_aws_signer_fetch_url_with_querystring(self) -> None: + def test_aws_signer_url_with_querystring_and_custom_header(self) -> None: region = "us-west-2" import requests + from botocore.awsrequest import AWSRequest from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth - auth = RequestsAWSV4SignerAuth(self.mock_session(), region) - - prepared_request = requests.Request( - "GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"} - ).prepare() + with patch( + "botocore.awsrequest.AWSRequest", side_effect=AWSRequest + ) as mock_aws_request: - signature_host = auth._fetch_url(prepared_request) + auth = RequestsAWSV4SignerAuth(self.mock_session(), region) + prepared_request = requests.Request( + "GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"} + ).prepare() + auth(prepared_request) - assert signature_host == "http://otherhost:443/?foo=bar" + mock_aws_request.assert_called_with( + method="GET", url="http://otherhost:443/?foo=bar", data=None + ) def test_aws_signer_as_http_auth(self) -> None: region = "us-west-2" @@ -525,9 +530,11 @@ def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None: ).prepare() auth(prepared_request) self.assertEqual(mock_sign.call_count, 1) - self.assertEqual( - mock_sign.call_args[0], - ("GET", "http://localhost/?key1=value1&key2=value2", None), + mock_sign.assert_called_with( + method="GET", + url="http://localhost/?key1=value1&key2=value2", + body=None, + headers={}, ) def test_aws_signer_consitent_url(self) -> None: