diff --git a/CHANGELOG.md b/CHANGELOG.md index d45a724..e2098d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.0] - 2023-10-10 + +### Added +- Added a content type parameter to the set stream content method in request information. + +### Changed +- Added dedicated HeadersCollection class to manage request headers. + ## [0.8.7] - 2023-10-05 ### Added diff --git a/kiota_abstractions/_version.py b/kiota_abstractions/_version.py index 538ada5..3f8ebe9 100644 --- a/kiota_abstractions/_version.py +++ b/kiota_abstractions/_version.py @@ -1 +1 @@ -VERSION: str = "0.8.7" +VERSION: str = "0.9.0" diff --git a/kiota_abstractions/authentication/api_key_authentication_provider.py b/kiota_abstractions/authentication/api_key_authentication_provider.py index 953a5e5..bcf29ec 100644 --- a/kiota_abstractions/authentication/api_key_authentication_provider.py +++ b/kiota_abstractions/authentication/api_key_authentication_provider.py @@ -74,4 +74,4 @@ async def authenticate_request( url_parts[4] = urlencode(query) request.url = urlunparse(url_parts) elif self.key_location == KeyLocation.Header: - request.add_request_headers({self.parameter_name: self.api_key}) + request.headers.add(self.parameter_name, self.api_key) diff --git a/kiota_abstractions/authentication/base_bearer_token_authentication_provider.py b/kiota_abstractions/authentication/base_bearer_token_authentication_provider.py index ff77c8e..a91a5e5 100644 --- a/kiota_abstractions/authentication/base_bearer_token_authentication_provider.py +++ b/kiota_abstractions/authentication/base_bearer_token_authentication_provider.py @@ -6,6 +6,7 @@ from typing import Any, Dict +from ..headers_collection import HeadersCollection from ..request_information import RequestInformation from .access_token_provider import AccessTokenProvider from .authentication_provider import AuthenticationProvider @@ -36,17 +37,18 @@ async def authenticate_request( if all( [ additional_authentication_context, self.CLAIMS_KEY - in additional_authentication_context, self.AUTHORIZATION_HEADER in request.headers + in additional_authentication_context, + request.headers.contains(self.AUTHORIZATION_HEADER) ] ): - del request.headers[self.AUTHORIZATION_HEADER] + request.headers.remove(self.AUTHORIZATION_HEADER) if not request.request_headers: - request.headers = {} + request.headers = HeadersCollection() - if not self.AUTHORIZATION_HEADER in request.headers: + if not request.headers.contains(self.AUTHORIZATION_HEADER): token = await self.access_token_provider.get_authorization_token( request.url, additional_authentication_context ) if token: - request.add_request_headers({f'{self.AUTHORIZATION_HEADER}': f'Bearer {token}'}) + request.headers.add(f'{self.AUTHORIZATION_HEADER}', f'Bearer {token}') diff --git a/kiota_abstractions/headers_collection.py b/kiota_abstractions/headers_collection.py new file mode 100644 index 0000000..827e952 --- /dev/null +++ b/kiota_abstractions/headers_collection.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from typing import Dict, List, Set, Union + + +class HeadersCollection(): + "Represents a collection of request/response headers" + SINGLE_VALUE_HEADERS: Set[str] = {"content-type", "content-encoding", "content-length"} + + def __init__(self) -> None: + self._headers: Dict[str, Set[str]] = {} + + def try_get(self, key: str) -> Union[bool, Set[str]]: + """Gets the header values corresponding to a specific header name. + + Args: + key (str): Header key. + + Returns: + Union[bool, Set[str]]: The header values for the specified header key or False. + """ + if not key: + raise ValueError("Header name cannot be null") + key = key.lower() + values = self._headers.get(key) + if values: + return values + return False + + def get_all(self) -> Dict[str, Set[str]]: + """Get all headers and values stored so far. + + Returns: + Dict[str, str]: The headers + """ + return self._headers + + def get(self, header_name: str) -> Set[str]: + """Get header values corresponding to a specific header. + + Args: + header_name (str): Header key. + + Returns: + Set[str]: Values for the header key + """ + if not header_name: + raise ValueError("Header name cannot be null") + header_name = header_name.lower() + values = self._headers.get(header_name) + if not values: + return set() + return values + + def try_add(self, header_name: str, header_value: str) -> bool: + """Adds values to the header with the specified name if it's not already present + + Args: + header_name (str): The name of the header to add values to. + header_value (str): The values to add to the header. + + Returns: + bool: If the header value has been added to headers. + """ + if not header_name: + raise ValueError("Header name cannot be null") + if header_value is None: + raise ValueError("Header value cannot be null") + header_name = header_name.lower() + if header_name not in self._headers: + self._headers[header_name] = {header_value} + return True + return False + + def add_all(self, headers: HeadersCollection) -> None: + """Adds the specified headers to the collection. + + Args: + headers (Dict[str, str]): The headers to add. + """ + if not headers: + raise ValueError("Headers cannot be null") + for key, values in headers.get_all().items(): + for value in values: + self.add(key, value) + + def add(self, header_name: str, header_values: Union[str, List[str]]) -> None: + """Adds values to the header with the specified name. + + Args: + header_name (str): The name of the header to add values to. + header_values (List[str]): The values to add to the header. + """ + if not header_name: + raise ValueError("Header name cannot be null") + if header_values is None: + raise ValueError("Header values cannot be null") + if not header_values: # empty list + return + header_name = header_name.lower() + if isinstance(header_values, list): + if header_name in self.SINGLE_VALUE_HEADERS: + self._headers[header_name] = {header_values[0]} + elif values := self.try_get(header_name): + for header_value in header_values: + values.add(header_value) #type: ignore + else: + self._headers[header_name] = set(header_values) + else: + if values := self.try_get(header_name): + values.add(header_values) #type: ignore + else: + self._headers[header_name] = {header_values} + + def keys(self) -> List[str]: + """Gets the header names present in the collection. + Returns: + List[str]: The header names present in the collection. + """ + return list(self._headers.keys()) + + def count(self): + """Gets the number of headers present in the collection.""" + return len(self._headers) + + def remove_value(self, header_name: str, header_value: str) -> Union[bool, Set[str]]: + """Removes the specified value from the header with the specified name. + + Args: + header_name (str): The name of the header to remove the value from. + header_value (str): The value to remove from the header. + + Returns: + bool: _description_ + """ + if not header_name: + raise ValueError("Header name cannot be null") + if header_value is None: + raise ValueError("Header value cannot be null") + header_name = header_name.lower() + values = self.try_get(header_name) + if values: + values.remove(header_value) #type: ignore + if bool(values): + return values + return self.remove(header_name) + + return False + + def remove(self, header_name: str) -> Union[bool, Set[str]]: + """Removes the header with the specified name. + + Args: + header_name (str): The name of the header to remove. + + Returns: + bool: True if the header has been removed, False otherwise. + """ + if not header_name: + raise ValueError("Header name cannot be null") + header_name = header_name.lower() + if self.contains(header_name): + return self._headers.pop(header_name) + return False + + def clear(self) -> None: + """Removes all headers from the collection. + """ + self._headers.clear() + + def contains(self, key: str) -> bool: + """Checks whether the collection contains a specific header. + + Args: + key (str): The name of the header to check for. + + Returns: + bool: True if the header is present, false otherwise. + """ + if not key: + raise ValueError("Header name cannot be null") + key = key.lower() + return key in self._headers diff --git a/kiota_abstractions/request_information.py b/kiota_abstractions/request_information.py index 091c700..b914396 100644 --- a/kiota_abstractions/request_information.py +++ b/kiota_abstractions/request_information.py @@ -9,6 +9,7 @@ from stduritemplate import StdUriTemplate from ._version import VERSION +from .headers_collection import HeadersCollection from .method import Method from .request_option import RequestOption from .serialization import Parsable, SerializationWriter @@ -54,7 +55,7 @@ def __init__(self) -> None: self.query_parameters: Dict[str, QueryParams] = {} # The Request Headers - self.headers: Dict[str, Set[str]] = {} # Use set to remove duplicates + self.headers: HeadersCollection = HeadersCollection() # The Request Body self.content: Optional[BytesIO] = None @@ -95,47 +96,10 @@ def url(self, url: Url) -> None: @property def request_headers(self) -> Optional[Dict]: final = {} - for key, value in self.headers.items(): - final[key] = ", ".join(value) + for key, values in self.headers.get_all().items(): + final[key] = ', '.join(values) return final - def add_request_headers( - self, headers_to_add: Optional[Dict[str, Union[str, List[str]]]] - ) -> None: - """Adds headers to the request""" - if headers_to_add: - for key, value in headers_to_add.items(): - lowercase_key = key.lower() - if lowercase_key in self.headers: - if isinstance(value, list): - self.headers[lowercase_key] = self.headers[lowercase_key].union(set(value)) - else: - self.headers[lowercase_key].add(str(value)) - else: - if isinstance(value, list): - self.headers[lowercase_key] = set(value) - else: - self.headers[lowercase_key] = {str(value)} - - def try_add_request_header(self, key: str, value: str) -> bool: - """Try to add an header to the request if it's not already set""" - if key and value: - lowercase_key = key.lower() - if lowercase_key in self.headers: - return False - self.headers[lowercase_key] = {str(value)} - return True - return False - - def remove_request_headers(self, key: str) -> None: - """Removes a request header from the current request - - Args: - key (str): The key of the header to remove - """ - if key and key.lower() in self.headers: - del self.headers[key.lower()] - @property def request_options(self) -> Dict[str, RequestOption]: """Gets the request options for the request.""" @@ -227,13 +191,15 @@ def set_content_from_scalar( writer_func(None, values) self._set_content_and_content_type_header(writer, content_type) - def set_stream_content(self, value: BytesIO) -> None: + def set_stream_content(self, value: BytesIO, content_type: Optional[str] = None) -> None: """Sets the request body to be a binary stream. Args: value (BytesIO): the binary stream """ - self.try_add_request_header(self.CONTENT_TYPE_HEADER, self.BINARY_CONTENT_TYPE) + if not content_type: + content_type = self.BINARY_CONTENT_TYPE + self.headers.try_add(self.CONTENT_TYPE_HEADER, content_type) self.content = value def set_query_string_parameters_from_raw_object(self, q: Optional[QueryParams]) -> None: @@ -283,7 +249,7 @@ def _set_content_and_content_type_header( self, writer: SerializationWriter, content_type: Optional[str] ): if content_type: - self.try_add_request_header(self.CONTENT_TYPE_HEADER, content_type) + self.headers.try_add(self.CONTENT_TYPE_HEADER, content_type) self.content = writer.get_serialized_content() def _decode_uri_string(self, uri: Optional[str]) -> str: diff --git a/tests/authentication/test_api_key_authentication_provider.py b/tests/authentication/test_api_key_authentication_provider.py index 02a6e1a..e9e07b6 100644 --- a/tests/authentication/test_api_key_authentication_provider.py +++ b/tests/authentication/test_api_key_authentication_provider.py @@ -61,5 +61,5 @@ async def test_header_location_authentication(mock_request_information): allowed_hosts, ) await provider.authenticate_request(mock_request_information) - assert "api_key" in mock_request_information.headers - assert mock_request_information.headers["api_key"] == {"test_key_string"} + assert "api_key" in mock_request_information.request_headers + assert mock_request_information.headers.get("api_key") == {"test_key_string"} diff --git a/tests/authentication/test_base_bearer_token_authentication.py b/tests/authentication/test_base_bearer_token_authentication.py index 2861004..0a45b20 100644 --- a/tests/authentication/test_base_bearer_token_authentication.py +++ b/tests/authentication/test_base_bearer_token_authentication.py @@ -22,5 +22,5 @@ async def test_authenticate_request(mock_request_information, mock_access_token_ await auth.authenticate_request(mock_request_information) assert mock_request_information - assert mock_request_information.headers == {'authorization': {'Bearer SomeToken'}} + assert mock_request_information.headers.get_all() == {'authorization': {'Bearer SomeToken'}} assert mock_request_information.request_headers == {'authorization': 'Bearer SomeToken'} diff --git a/tests/test_request_header.py b/tests/test_request_header.py new file mode 100644 index 0000000..abaf5f1 --- /dev/null +++ b/tests/test_request_header.py @@ -0,0 +1,182 @@ +import pytest + +from kiota_abstractions.headers_collection import HeadersCollection + +def test_defensive(): + """Tests initialization of RequestHeader objects + """ + headers = HeadersCollection() + with pytest.raises(ValueError): + headers.try_get(None) + with pytest.raises(ValueError): + headers.try_get("") + with pytest.raises(ValueError): + headers.get(None) + with pytest.raises(ValueError): + headers.get("") + with pytest.raises(ValueError): + headers.try_add(None, "value") + with pytest.raises(ValueError): + headers.try_add("", "value") + with pytest.raises(ValueError): + headers.try_add("header", None) + with pytest.raises(ValueError): + headers.add_all(None) + with pytest.raises(ValueError): + headers.add(None, "value") + with pytest.raises(ValueError): + headers.add("", "value") + with pytest.raises(ValueError): + headers.add("header", None) + with pytest.raises(ValueError): + headers.remove_value(None, "value") + with pytest.raises(ValueError): + headers.remove_value("", "value") + with pytest.raises(ValueError): + headers.remove_value("header", None) + with pytest.raises(ValueError): + headers.remove(None) + with pytest.raises(ValueError): + headers.remove("") + with pytest.raises(ValueError): + headers.contains(None) + with pytest.raises(ValueError): + headers.contains("") + +def test_normalizes_casing(): + headers = HeadersCollection() + headers.add("heaDER1", "value1") + assert {"value1"} <= headers.try_get("header1") + assert {"value1"} <= headers.get("header1") + +def test_adds_to_non_existent_header(): + """Tests adding a header to a non-existent header + """ + headers = HeadersCollection() + headers.add("header1", "value1") + assert {"value1"} <= headers.try_get("header1") + assert {"value1"} <= headers.get("header1") + assert headers.contains("header1") + assert headers.count() == 1 + +def test_try_adds_to_non_existent_header(): + """Tests try adding a header to a non-existent header + """ + headers = HeadersCollection() + assert headers.try_add("header1", "value1") + assert {"value1"} <= headers.try_get("header1") + assert {"value1"} <= headers.get("header1") + assert headers.contains("header1") + assert headers.count() == 1 + +def test_adds_to_existing_header(): + """Tests adding a header to an existing header + """ + headers = HeadersCollection() + headers.add("header1", "value1") + headers.add("header1", "value2") + assert {"value1", "value2"} <= headers.try_get("header1") + assert {"value1", "value2"} <= headers.get("header1") + assert headers.contains("header1") + assert headers.count() == 1 + +def test_try_adds_to_existing_header(): + """Tests try adding a header to an existing header + """ + headers = HeadersCollection() + assert headers.try_add("header1", "value1") + assert not headers.try_add("header1", "value2") + assert {"value1"} <= headers.try_get("header1") + assert {"value1"} <= headers.get("header1") + assert headers.contains("header1") + assert headers.count() == 1 + +def test_add_single_value_header_to_existing_header(): + """Tests adding a single value header to an existing header + """ + headers = HeadersCollection() + headers.add("content-type", "value1") + headers.add("content-type", "value2") + assert {"value2"} <= headers.try_get("content-type") + assert {"value2"} <= headers.get("content-type") + assert headers.contains("content-type") + assert headers.count() == 1 + +def test_try_add_single_value_header_to_existing_header(): + """Tests adding a single value header to an existing header + """ + headers = HeadersCollection() + headers.try_add("content-type", "value1") + headers.try_add("content-type", "value2") + assert {"value1"} <= headers.try_get("content-type") + assert {"value1"} <= headers.get("content-type") + assert headers.contains("content-type") + assert headers.count() == 1 + +def test_removes_value_from_existing_header(): + """Tests removing a value from an existing header + """ + headers = HeadersCollection() + headers.remove_value("header1", "value1") + headers.add("header1", "value1") + headers.add("header1", "value2") + assert headers.contains("header1") + assert headers.count() == 1 + headers.remove_value("header1", "value1") + assert {"value2"} <= headers.try_get("header1") + headers.remove_value("header1", "value2") + assert not headers.contains("header1") + assert headers.count() == 0 + +def test_removes_header(): + """Tests removing a header + """ + headers = HeadersCollection() + headers.add("header1", "value1") + headers.add("header1", "value2") + assert headers.contains("header1") + assert headers.count() == 1 + headers.remove("header1") + assert not headers.contains("header1") + assert headers.count() == 0 + +def test_clears_headers(): + """Tests clearing headers + """ + headers = HeadersCollection() + headers.add("header1", "value1") + headers.add("header1", "value2") + headers.add("header2", "value3") + headers.add("header2", "value4") + assert headers.contains("header1") + assert headers.contains("header2") + assert headers.count() == 2 + headers.clear() + assert not headers.contains("header1") + assert not headers.contains("header2") + assert headers.count() == 0 + assert headers.keys() == [] + +def test_adds_headers_from_instance(): + """Tests adding headers from another instance + """ + headers = HeadersCollection() + headers.add("header1", "value1") + headers.add("header1", "value2") + headers.add("header2", "value3") + headers.add("header2", "value4") + assert headers.contains("header1") + assert headers.contains("header2") + assert headers.count() == 2 + headers2 = HeadersCollection() + headers2.add("header3", "value5") + headers2.add("header3", "value6") + headers2.add("header4", "value7") + headers2.add("header4", "value8") + headers.add_all(headers2) + assert headers.contains("header1") + assert headers.contains("header2") + assert headers.contains("header3") + assert headers.contains("header4") + assert headers.count() == 4 + assert headers.keys() == ["header1", "header2", "header3", "header4"] \ No newline at end of file diff --git a/tests/test_request_information.py b/tests/test_request_information.py index 13af504..8907374 100644 --- a/tests/test_request_information.py +++ b/tests/test_request_information.py @@ -1,6 +1,7 @@ import pytest from kiota_abstractions.request_information import RequestInformation +from kiota_abstractions.headers_collection import HeadersCollection def test_initialization(): @@ -10,68 +11,45 @@ def test_initialization(): assert request_info assert not request_info.path_parameters assert not request_info.query_parameters - assert not request_info.headers assert not request_info.request_options assert not request_info.url_template assert not request_info.http_method assert not request_info.content + assert request_info.headers assert request_info.RAW_URL_KEY == 'request-raw-url' assert request_info.BINARY_CONTENT_TYPE == 'application/octet-stream' assert request_info.CONTENT_TYPE_HEADER == 'Content-Type' -def test_add_request_headers_null(mock_request_information): - """Tests adding a null request header - """ - mock_request_information.add_request_headers(None) - assert mock_request_information.headers == {} - - -def test_add_request_headers_value_string(mock_request_information): +def test_add_request_headers(mock_request_information): """Tests adding a request header with a string value """ - mock_request_information.add_request_headers({"header1": "value1"}) - mock_request_information.add_request_headers({"header2": "value2"}) - assert {"value1"} <= mock_request_information.headers["header1"] - assert {"value2"} <= mock_request_information.headers["header2"] - mock_request_information.add_request_headers({"header1": "value3"}) - assert {"value1", "value3"} <= mock_request_information.headers["header1"] - - -def test_add_request_headers_value_list(mock_request_information): - """Tests adding a request header with a list value - """ - mock_request_information.add_request_headers({"header1": ["value1", "value2"]}) - mock_request_information.add_request_headers({"header2": ["value3", "value4"]}) - assert {"value1", "value2"} <= mock_request_information.headers["header1"] - assert {"value3", "value4"} <= mock_request_information.headers["header2"] - mock_request_information.add_request_headers({"header1": ["value5", "value6"]}) - assert {"value1", "value2", "value5", "value6"} <= mock_request_information.headers["header1"] - - -def test_add_request_headers_value_normalizes_cases(mock_request_information): - """Tests adding a request header normalizes the casing of the header name - """ - mock_request_information.add_request_headers({"heaDER1": "value1"}) - mock_request_information.add_request_headers({"HEAder2": "value2"}) - assert {"value1"} <= mock_request_information.headers["header1"] - assert {"value2"} <= mock_request_information.headers["header2"] - mock_request_information.add_request_headers({"HEADER1": "value3"}) - assert {"value1", "value3"} <= mock_request_information.headers["header1"] + headers = HeadersCollection() + headers.add("header1", "value1") + headers.add("header2", "value2") + mock_request_information.headers.add_all(headers) + assert {"value1"} <= mock_request_information.headers.get("header1") + assert {"value2"} <= mock_request_information.headers.get("header2") + header2 = HeadersCollection() + header2.add("header1", "value3") + mock_request_information.headers.add_all(header2) + assert {"value1", "value3"} <= mock_request_information.headers.get("header1") def test_request_headers(mock_request_information): """Test the final request headers """ - mock_request_information.add_request_headers({"header1": ["value1", "value2"]}) - mock_request_information.add_request_headers({"header2": ["value3", "value4"]}) + headers = HeadersCollection() + headers.add("header1", ["value1", "value2"]) + headers.add("header2", ["value3", "value4"]) + mock_request_information.headers.add_all(headers) assert "value1" in mock_request_information.request_headers["header1"] assert "value2" in mock_request_information.request_headers["header1"] assert "value3" in mock_request_information.request_headers["header2"] assert "value4" in mock_request_information.request_headers["header2"] - mock_request_information.add_request_headers( - {"header1": ["value1", "value2", "value5", "value6"]} - ) + headers2 = HeadersCollection() + headers2.add("header1", ["value1", "value2", "value5", "value6"]) + mock_request_information.headers.add_all(headers2) assert "value5" in mock_request_information.request_headers["header1"] assert "value6" in mock_request_information.request_headers["header1"] @@ -84,24 +62,19 @@ def test_request_headers(mock_request_information): assert "value3" in mock_request_information.request_headers["header2"] assert "value4" in mock_request_information.request_headers["header2"] -def test__try_add_request_header(mock_request_information): - """Test the final request header after try_add - """ - assert mock_request_information.try_add_request_header("header1", "value1") == True - assert mock_request_information.try_add_request_header("header1", "value2") == False - assert "value1" in mock_request_information.request_headers["header1"] - def test_remove_request_headers(mock_request_information): """Tests removing a request header """ - mock_request_information.add_request_headers({"header1": "value1"}) - mock_request_information.add_request_headers({"header2": "value2"}) - assert mock_request_information.headers["header1"] == {"value1"} - assert mock_request_information.headers["header2"] == {"value2"} - mock_request_information.remove_request_headers("header1") - mock_request_information.remove_request_headers("header3") - assert 'header1' not in mock_request_information.headers - assert mock_request_information.headers["header2"] == {"value2"} + headers = HeadersCollection() + headers.add("header1", "value1") + headers.add("header2", "value2") + mock_request_information.headers.add_all(headers) + assert mock_request_information.headers.get("header1") == {"value1"} + assert mock_request_information.headers.get("header2") == {"value2"} + mock_request_information.headers.remove("header1") + mock_request_information.headers.remove("header3") + assert 'header1' not in mock_request_information.request_headers + assert mock_request_information.headers.try_get("header2") == {"value2"} def test_set_stream_content(mock_request_information): @@ -109,4 +82,4 @@ def test_set_stream_content(mock_request_information): """ mock_request_information.set_stream_content(b'stream') assert mock_request_information.content == b'stream' - assert mock_request_information.headers["content-type"] == {"application/octet-stream"} + assert mock_request_information.headers.get("content-type") == {"application/octet-stream"}