Skip to content

Commit

Permalink
Merge pull request #113 from microsoft/feature/CAE-support
Browse files Browse the repository at this point in the history
Feature/cae support
  • Loading branch information
samwelkanda authored Aug 30, 2023
2 parents 4f222cf + 03b30d8 commit d568922
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added opentelemetry to support observability.
- Added an additional parameter to authentication methods to carry contextual information.

## [0.7.1] - 2023-08-09

Expand Down
12 changes: 11 additions & 1 deletion kiota_abstractions/authentication/access_token_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from abc import ABC, abstractmethod
from typing import Any, Dict

from .allowed_hosts_validator import AllowedHostsValidator

Expand All @@ -8,12 +15,15 @@ class AccessTokenProvider(ABC):
"""

@abstractmethod
async def get_authorization_token(self, uri: str) -> str:
async def get_authorization_token(
self, uri: str, additional_authentication_context: Dict[str, Any] = {}
) -> str:
"""This method is called by the BaseBearerTokenAuthenticationProvider class to get the
access token.
Args:
uri (str): The target URI to get an access token for.
additional_authentication_context (dict):
Returns:
str: The access token to use for the request.
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from typing import Any, Dict

from ..request_information import RequestInformation
from .authentication_provider import AuthenticationProvider

Expand All @@ -9,10 +17,15 @@ class AnonymousAuthenticationProvider(AuthenticationProvider):
AuthenticationProvider (ABC): The abstract base class that this class implements
"""

async def authenticate_request(self, request: RequestInformation) -> None:
async def authenticate_request(
self,
request: RequestInformation,
additional_authentication_context: Dict[str, Any] = {}
) -> None:
"""Authenticates the provided request information
Args:
request (RequestInformation): Request information object
additional_authentication_context (dict):
"""
return
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from enum import Enum
from typing import List
from typing import Any, Dict, List
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

from kiota_abstractions.request_information import RequestInformation
Expand Down Expand Up @@ -47,7 +53,11 @@ def __init__(
self.parameter_name = parameter_name
self.allowed_hosts_validator = AllowedHostsValidator(allowed_hosts)

async def authenticate_request(self, request: RequestInformation) -> None:
async def authenticate_request(
self,
request: RequestInformation,
additional_authentication_context: Dict[str, Any] = {}
) -> None:
"""
Ensures that the API key is placed in the correct location for a request.
"""
Expand Down
14 changes: 13 additions & 1 deletion kiota_abstractions/authentication/authentication_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from abc import ABC, abstractmethod
from typing import Any, Dict

from ..request_information import RequestInformation

Expand All @@ -9,10 +16,15 @@ class AuthenticationProvider(ABC):
"""

@abstractmethod
async def authenticate_request(self, request: RequestInformation) -> None:
async def authenticate_request(
self,
request: RequestInformation,
additional_authentication_context: Dict[str, Any] = {}
) -> None:
"""Authenticates the application request
Args:
request (RequestInformation): The request to authenticate
additional_authentication_context (dict):
"""
pass
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All Rights Reserved.
# Licensed under the MIT License.
# See License in the project root for license information.
# ------------------------------------------------------------------------------

from typing import Any, Dict

from ..request_information import RequestInformation
from .access_token_provider import AccessTokenProvider
from .authentication_provider import AuthenticationProvider
Expand All @@ -7,11 +15,16 @@ class BaseBearerTokenAuthenticationProvider(AuthenticationProvider):
"""Provides a base class for implementing AuthenticationProvider for Bearer token scheme.
"""
AUTHORIZATION_HEADER = "Authorization"
CLAIMS_KEY = "claims"

def __init__(self, access_token_provider: AccessTokenProvider) -> None:
self.access_token_provider = access_token_provider

async def authenticate_request(self, request: RequestInformation) -> None:
async def authenticate_request(
self,
request: RequestInformation,
additional_authentication_context: Dict[str, Any] = {}
) -> None:
"""Authenticates the provided RequestInformation instance using the provided
authorization token
Expand All @@ -20,10 +33,20 @@ async def authenticate_request(self, request: RequestInformation) -> None:
"""
if not request:
raise Exception("Request cannot be null")
if all(
[
additional_authentication_context, self.CLAIMS_KEY
in additional_authentication_context, self.AUTHORIZATION_HEADER in request.headers
]
):
del request.headers[self.AUTHORIZATION_HEADER]

if not request.request_headers:
request.headers = {}

if not self.AUTHORIZATION_HEADER in request.headers:
token = await self.access_token_provider.get_authorization_token(request.url)
token = await self.access_token_provider.get_authorization_token(
request.url, additional_authentication_context
)
if token:
request.add_request_headers({f'{self.AUTHORIZATION_HEADER}': f'Bearer {token}'})
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ class MockAccessTokenProvider(AccessTokenProvider):
def __init__(self):
self.token = None

async def get_authorization_token(self, url: str) -> str:
async def get_authorization_token(
self,
url: str,
additional_authentication_context: Dict[str, Any] = {}
) -> str:
return "SomeToken"

def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
Expand Down

0 comments on commit d568922

Please sign in to comment.