Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/refactor request headers #162

Merged
merged 13 commits into from
Oct 18, 2023
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.9.0] - 2023-10-10

### Added
- Added a content type parameter to the set stream content method in request information.

### Changed
- Added dedicated HeadersCollection class to manage request headers.

## [0.8.7] - 2023-10-05

### Added
Expand Down
2 changes: 1 addition & 1 deletion kiota_abstractions/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION: str = "0.8.7"
VERSION: str = "0.9.0"
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ async def authenticate_request(
url_parts[4] = urlencode(query)
request.url = urlunparse(url_parts)
elif self.key_location == KeyLocation.Header:
request.add_request_headers({self.parameter_name: self.api_key})
request.headers.add(self.parameter_name, self.api_key)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Any, Dict

from ..headers_collection import HeadersCollection
from ..request_information import RequestInformation
from .access_token_provider import AccessTokenProvider
from .authentication_provider import AuthenticationProvider
Expand Down Expand Up @@ -36,17 +37,18 @@ async def authenticate_request(
if all(
[
additional_authentication_context, self.CLAIMS_KEY
in additional_authentication_context, self.AUTHORIZATION_HEADER in request.headers
in additional_authentication_context,
request.headers.contains(self.AUTHORIZATION_HEADER)
]
):
del request.headers[self.AUTHORIZATION_HEADER]
request.headers.remove(self.AUTHORIZATION_HEADER)

if not request.request_headers:
request.headers = {}
request.headers = HeadersCollection()

if not self.AUTHORIZATION_HEADER in request.headers:
if not request.headers.contains(self.AUTHORIZATION_HEADER):
token = await self.access_token_provider.get_authorization_token(
request.url, additional_authentication_context
)
if token:
request.add_request_headers({f'{self.AUTHORIZATION_HEADER}': f'Bearer {token}'})
request.headers.add(f'{self.AUTHORIZATION_HEADER}', f'Bearer {token}')
183 changes: 183 additions & 0 deletions kiota_abstractions/headers_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from __future__ import annotations

from typing import Dict, List, Set, Union


class HeadersCollection():
"Represents a collection of request/response headers"
SINGLE_VALUE_HEADERS: Set[str] = {"content-type", "content-encoding", "content-length"}

def __init__(self) -> None:
self._headers: Dict[str, Set[str]] = {}

def try_get(self, key: str) -> Union[bool, Set[str]]:
"""Gets the header values corresponding to a specific header name.

Args:
key (str): Header key.

Returns:
Union[bool, Set[str]]: The header values for the specified header key or False.
"""
if not key:
raise ValueError("Header name cannot be null")
key = key.lower()
values = self._headers.get(key)
if values:
return values
return False

def get_all(self) -> Dict[str, Set[str]]:
"""Get all headers and values stored so far.

Returns:
Dict[str, str]: The headers
"""
return self._headers

def get(self, header_name: str) -> Set[str]:
"""Get header values corresponding to a specific header.

Args:
header_name (str): Header key.

Returns:
Set[str]: Values for the header key
"""
if not header_name:
raise ValueError("Header name cannot be null")
header_name = header_name.lower()
values = self._headers.get(header_name)
if not values:
return set()
return values

def try_add(self, header_name: str, header_value: str) -> bool:
"""Adds values to the header with the specified name if it's not already present

Args:
header_name (str): The name of the header to add values to.
header_value (str): The values to add to the header.

Returns:
bool: If the header value has been added to headers.
"""
if not header_name:
raise ValueError("Header name cannot be null")
if header_value is None:
raise ValueError("Header value cannot be null")
header_name = header_name.lower()
if header_name not in self._headers:
self._headers[header_name] = {header_value}
return True
return False

def add_all(self, headers: HeadersCollection) -> None:
"""Adds the specified headers to the collection.

Args:
headers (Dict[str, str]): The headers to add.
"""
if not headers:
raise ValueError("Headers cannot be null")
for key, values in headers.get_all().items():
for value in values:
self.add(key, value)

def add(self, header_name: str, header_values: Union[str, List[str]]) -> None:
"""Adds values to the header with the specified name.

Args:
header_name (str): The name of the header to add values to.
header_values (List[str]): The values to add to the header.
"""
if not header_name:
raise ValueError("Header name cannot be null")
if header_values is None:
raise ValueError("Header values cannot be null")
if not header_values: # empty list
return
header_name = header_name.lower()
if isinstance(header_values, list):
if header_name in self.SINGLE_VALUE_HEADERS:
self._headers[header_name] = {header_values[0]}
elif values := self.try_get(header_name):
for header_value in header_values:
values.add(header_value) #type: ignore
else:
self._headers[header_name] = set(header_values)
else:
if values := self.try_get(header_name):
values.add(header_values) #type: ignore
else:
self._headers[header_name] = {header_values}

def keys(self) -> List[str]:
"""Gets the header names present in the collection.
Returns:
List[str]: The header names present in the collection.
"""
return list(self._headers.keys())

def count(self):
"""Gets the number of headers present in the collection."""
return len(self._headers)

def remove_value(self, header_name: str, header_value: str) -> Union[bool, Set[str]]:
"""Removes the specified value from the header with the specified name.

Args:
header_name (str): The name of the header to remove the value from.
header_value (str): The value to remove from the header.

Returns:
bool: _description_
"""
if not header_name:
raise ValueError("Header name cannot be null")
if header_value is None:
raise ValueError("Header value cannot be null")
header_name = header_name.lower()
values = self.try_get(header_name)
if values:
values.remove(header_value) #type: ignore
if bool(values):
return values
return self.remove(header_name)

return False

def remove(self, header_name: str) -> Union[bool, Set[str]]:
"""Removes the header with the specified name.

Args:
header_name (str): The name of the header to remove.

Returns:
bool: True if the header has been removed, False otherwise.
"""
if not header_name:
raise ValueError("Header name cannot be null")
header_name = header_name.lower()
if self.contains(header_name):
return self._headers.pop(header_name)
return False

def clear(self) -> None:
"""Removes all headers from the collection.
"""
self._headers.clear()

def contains(self, key: str) -> bool:
"""Checks whether the collection contains a specific header.

Args:
key (str): The name of the header to check for.

Returns:
bool: True if the header is present, false otherwise.
"""
if not key:
raise ValueError("Header name cannot be null")
key = key.lower()
return key in self._headers
52 changes: 9 additions & 43 deletions kiota_abstractions/request_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from stduritemplate import StdUriTemplate

from ._version import VERSION
from .headers_collection import HeadersCollection
from .method import Method
from .request_option import RequestOption
from .serialization import Parsable, SerializationWriter
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(self) -> None:
self.query_parameters: Dict[str, QueryParams] = {}

# The Request Headers
self.headers: Dict[str, Set[str]] = {} # Use set to remove duplicates
self.headers: HeadersCollection = HeadersCollection()

# The Request Body
self.content: Optional[BytesIO] = None
Expand Down Expand Up @@ -95,47 +96,10 @@ def url(self, url: Url) -> None:
@property
def request_headers(self) -> Optional[Dict]:
final = {}
for key, value in self.headers.items():
final[key] = ", ".join(value)
for key, values in self.headers.get_all().items():
final[key] = ', '.join(values)
return final

def add_request_headers(
self, headers_to_add: Optional[Dict[str, Union[str, List[str]]]]
) -> None:
"""Adds headers to the request"""
if headers_to_add:
for key, value in headers_to_add.items():
lowercase_key = key.lower()
if lowercase_key in self.headers:
if isinstance(value, list):
self.headers[lowercase_key] = self.headers[lowercase_key].union(set(value))
else:
self.headers[lowercase_key].add(str(value))
else:
if isinstance(value, list):
self.headers[lowercase_key] = set(value)
else:
self.headers[lowercase_key] = {str(value)}

def try_add_request_header(self, key: str, value: str) -> bool:
"""Try to add an header to the request if it's not already set"""
if key and value:
lowercase_key = key.lower()
if lowercase_key in self.headers:
return False
self.headers[lowercase_key] = {str(value)}
return True
return False

def remove_request_headers(self, key: str) -> None:
"""Removes a request header from the current request

Args:
key (str): The key of the header to remove
"""
if key and key.lower() in self.headers:
del self.headers[key.lower()]

@property
def request_options(self) -> Dict[str, RequestOption]:
"""Gets the request options for the request."""
Expand Down Expand Up @@ -227,13 +191,15 @@ def set_content_from_scalar(
writer_func(None, values)
self._set_content_and_content_type_header(writer, content_type)

def set_stream_content(self, value: BytesIO) -> None:
def set_stream_content(self, value: BytesIO, content_type: Optional[str] = None) -> None:
"""Sets the request body to be a binary stream.

Args:
value (BytesIO): the binary stream
"""
self.try_add_request_header(self.CONTENT_TYPE_HEADER, self.BINARY_CONTENT_TYPE)
if not content_type:
content_type = self.BINARY_CONTENT_TYPE
self.headers.try_add(self.CONTENT_TYPE_HEADER, content_type)
self.content = value

def set_query_string_parameters_from_raw_object(self, q: Optional[QueryParams]) -> None:
Expand Down Expand Up @@ -283,7 +249,7 @@ def _set_content_and_content_type_header(
self, writer: SerializationWriter, content_type: Optional[str]
):
if content_type:
self.try_add_request_header(self.CONTENT_TYPE_HEADER, content_type)
self.headers.try_add(self.CONTENT_TYPE_HEADER, content_type)
self.content = writer.get_serialized_content()

def _decode_uri_string(self, uri: Optional[str]) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ async def test_header_location_authentication(mock_request_information):
allowed_hosts,
)
await provider.authenticate_request(mock_request_information)
assert "api_key" in mock_request_information.headers
assert mock_request_information.headers["api_key"] == {"test_key_string"}
assert "api_key" in mock_request_information.request_headers
assert mock_request_information.headers.get("api_key") == {"test_key_string"}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ async def test_authenticate_request(mock_request_information, mock_access_token_
await auth.authenticate_request(mock_request_information)

assert mock_request_information
assert mock_request_information.headers == {'authorization': {'Bearer SomeToken'}}
assert mock_request_information.headers.get_all() == {'authorization': {'Bearer SomeToken'}}
assert mock_request_information.request_headers == {'authorization': 'Bearer SomeToken'}
Loading