Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small refactor of AWS Signer classes for both sync and async clients #866

Merged
merged 9 commits into from
Dec 3, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Added sync and async sample that uses `search_after` parameter ([859](https://github.com/opensearch-project/opensearch-py/pull/859))
### Updated APIs
### Changed
- Small refactor of AWS Signer classes for both sync and async clients ([866](https://github.com/opensearch-project/opensearch-py/pull/866))
### Deprecated
### Removed
### Fixed
Expand Down
77 changes: 9 additions & 68 deletions opensearchpy/helpers/asyncsigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
# GitHub history for details.

from typing import Any, Dict, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse

from opensearchpy.helpers.signer import AWSV4Signer


class AWSV4SignerAsyncAuth:
Expand All @@ -17,33 +18,21 @@ class AWSV4SignerAsyncAuth:
"""

def __init__(self, credentials: Any, region: str, service: str = "es") -> None:
if not credentials:
raise ValueError("Credentials cannot be empty")
self.credentials = credentials

if not region:
raise ValueError("Region cannot be empty")
self.region = region

if not service:
raise ValueError("Service name cannot be empty")
self.service = service
self.signer = AWSV4Signer(credentials, region, service)

def __call__(
self,
method: str,
url: str,
query_string: Optional[str] = None,
body: Optional[Union[str, bytes]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
return self._sign_request(method, url, query_string, body, headers)
return self._sign_request(method=method, url=url, body=body, headers=headers)

def _sign_request(
self,
method: str,
url: str,
query_string: Optional[str],
body: Optional[Union[str, bytes]],
headers: Optional[Dict[str, str]],
) -> Dict[str, str]:
Expand All @@ -53,58 +42,10 @@ def _sign_request(
:return: signed headers
"""

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

signature_host = self._fetch_url(url, headers or dict())

# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(
updated_headers = self.signer.sign(
method=method,
url=signature_host,
data=body,
)

# credentials objects expose access_key, secret_key and token attributes
# via @property annotations that call _refresh() on every access,
# creating a race condition if the credentials expire before secret_key
# is called but after access_key- the end result is the access_key doesn't
# correspond to the secret_key used to sign the request. To avoid this,
# get_frozen_credentials() which returns non-refreshing credentials is
# called if it exists.
credentials = (
self.credentials.get_frozen_credentials()
if hasattr(self.credentials, "get_frozen_credentials")
and callable(self.credentials.get_frozen_credentials)
else self.credentials
url=url,
body=body,
headers=headers,
)

sig_v4_auth = SigV4Auth(credentials, self.service, self.region)
sig_v4_auth.add_auth(aws_request)
aws_request.headers["X-Amz-Content-SHA256"] = sig_v4_auth.payload(aws_request)

# copy the headers from AWS request object into the prepared_request
return dict(aws_request.headers.items())

def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that rebuilds the url that gets signed.
An empty path becomes `/`, the query string is re-encoded, and a `host`
header (matched case-insensitively) takes precedence over the host
portion of the url.
:param url: request url
:param headers: request headers, may be None or empty
:return: reconstructed url
"""
parsed_url = urlparse(url)
path = parsed_url.path or "/"

# re-encode the query string if present in the request, keeping blank
# values and repeated keys
querystring = ""
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)

# lower-case the header names so the host lookup is case-insensitive;
# fall back to the url's netloc when the host header is missing or empty
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc

# construct the url and return
return parsed_url.scheme + "://" + location + path + querystring
return updated_headers
78 changes: 42 additions & 36 deletions opensearchpy/helpers/signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional
from urllib.parse import parse_qs, urlencode, urlparse

import requests
Expand All @@ -31,7 +31,9 @@ def __init__(self, credentials, region: str, service: str = "es") -> Any: # typ
raise ValueError("Service name cannot be empty")
self.service = service

def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
def sign(
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""
This method signs the request and returns headers.
:param method: HTTP method
Expand All @@ -43,8 +45,10 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

signature_host = self._fetch_url(url, headers or dict())

# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(method=method.upper(), url=url, data=body)
aws_request = AWSRequest(method=method.upper(), url=signature_host, data=body)

# credentials objects expose access_key, secret_key and token attributes
# via @property annotations that call _refresh() on every access,
Expand All @@ -69,6 +73,30 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:

return headers

@staticmethod
def _fetch_url(url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that rebuilds the url that gets signed.
An empty path becomes `/`, the query string is re-encoded, and a `host`
header (matched case-insensitively) takes precedence over the host
portion of the url.
:param url: request url
:param headers: request headers, may be None or empty
:return: reconstructed url
"""
parsed_url = urlparse(url)
path = parsed_url.path or "/"

# re-encode the query string if present in the request, keeping blank
# values and repeated keys
querystring = ""
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)

# lower-case the header names so the host lookup is case-insensitive;
# fall back to the url's netloc when the host header is missing or empty
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc

# construct the url and return
return parsed_url.scheme + "://" + location + path + querystring


class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
"""
Expand All @@ -89,40 +117,16 @@ def _sign_request(self, prepared_request): # type: ignore
:return: signed request
"""

prepared_request.headers.update(
self.signer.sign(
prepared_request.method,
self._fetch_url(prepared_request),
prepared_request.body,
)
updated_headers = self.signer.sign(
method=prepared_request.method,
url=prepared_request.url,
body=prepared_request.body,
headers=prepared_request.headers,
)

return prepared_request

def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str:
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
:return: reconstructed url
"""
url = urlparse(prepared_request.url)
path = url.path or "/"

# fetch the query string if present in the request
querystring = ""
if url.query:
querystring = "?" + urlencode(
parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore
)
prepared_request.headers.update(updated_headers)

# fetch the host information from headers
headers = {
key.lower(): value for key, value in prepared_request.headers.items()
}
location = headers.get("host") or url.netloc

# construct the url and return
return url.scheme + "://" + location + path + querystring # type: ignore
return prepared_request


# Deprecated: use RequestsAWSV4SignerAuth
Expand All @@ -135,5 +139,7 @@ def __init__(self, credentials, region, service: str = "es") -> None: # type: i
self.signer = AWSV4Signer(credentials, region, service)
self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600

def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]:
return self.signer.sign(method, url, body)
def __call__(
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
return self.signer.sign(method, url, body, headers)
20 changes: 11 additions & 9 deletions test_opensearchpy/test_async/test_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import uuid
from typing import Any, Collection, Dict, Mapping, Optional, Tuple, Union
from unittest.mock import Mock
from unittest.mock import Mock, patch

import pytest
from _pytest.mark.structures import MarkDecorator
Expand Down Expand Up @@ -81,15 +81,18 @@ async def test_aws_signer_async_fetch_url_with_querystring(self) -> None:
region = "us-west-2"
service = "aoss"

from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth

auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
from botocore.awsrequest import AWSRequest

signature_host = auth._fetch_url(
"http://localhost/?foo=bar", headers={"host": "otherhost"}
)
from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth

assert signature_host == "http://otherhost/?foo=bar"
with patch(
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
) as mock_aws_request:
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
auth("GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"})
mock_aws_request.assert_called_with(
method="GET", url="http://otherhost:443/?foo=bar", data=None
)


class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner):
Expand Down Expand Up @@ -155,7 +158,6 @@ def _sign_request(
self,
method: str,
url: str,
query_string: Optional[str] = None,
body: Optional[Union[str, bytes]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,22 +457,27 @@ def mock_session(self) -> Any:

return dummy_session

def test_aws_signer_fetch_url_with_querystring(self) -> None:
def test_aws_signer_url_with_querystring_and_custom_header(self) -> None:
region = "us-west-2"

import requests
from botocore.awsrequest import AWSRequest

from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth

auth = RequestsAWSV4SignerAuth(self.mock_session(), region)

prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()
with patch(
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
) as mock_aws_request:

signature_host = auth._fetch_url(prepared_request)
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()
auth(prepared_request)

assert signature_host == "http://otherhost:443/?foo=bar"
mock_aws_request.assert_called_with(
method="GET", url="http://otherhost:443/?foo=bar", data=None
)

def test_aws_signer_as_http_auth(self) -> None:
region = "us-west-2"
Expand Down Expand Up @@ -525,9 +530,11 @@ def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None:
).prepare()
auth(prepared_request)
self.assertEqual(mock_sign.call_count, 1)
self.assertEqual(
mock_sign.call_args[0],
("GET", "http://localhost/?key1=value1&key2=value2", None),
mock_sign.assert_called_with(
method="GET",
url="http://localhost/?key1=value1&key2=value2",
body=None,
headers={},
)

def test_aws_signer_consitent_url(self) -> None:
Expand Down
Loading