Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow service account authentication #16

Merged
merged 6 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 127 additions & 18 deletions certbot_dns_stackit/stackit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging
from dataclasses import dataclass
from typing import Optional, List, Callable

from typing import Optional, List, Callable, TypedDict
import jwt
import jwt.help
import json
import time
import uuid
import requests

from certbot import errors
from certbot.plugins import dns_common

Expand All @@ -25,6 +30,25 @@
records: List[Record]


class ServiceFileCredentials(TypedDict):
"""
Represents the credentials obtained from a service file for authentication.

Attributes:
iss (str): The issuer of the token, typically the email address of the service account.
sub (str): The subject of the token, usually the same as `iss` unless acting on behalf of another user.
aud (str): The audience for the token, indicating the intended recipient, usually the authentication URL.
kid (str): The key ID used for identifying the private key corresponding to the public key.
privateKey (str): The private key used to sign the authentication token.
"""

iss: str
sub: str
aud: str
kid: str
privateKey: str


class _StackitClient(object):
"""
A client to interact with the STACKIT DNS API.
Expand Down Expand Up @@ -137,12 +161,12 @@
:param domain: The domain (zone dnsName) for which the zone ID is needed.
:return: The ID of the zone.
"""
parts = domain.split('.')
parts = domain.split(".")

# we are searching for the best matching zone. We can do that by iterating over the parts of the domain
# from left to right.
for i in range(len(parts)):
subdomain = '.'.join(parts[i:])
subdomain = ".".join(parts[i:])
res = requests.get(
f"{self.base_url}/v1/projects/{self.project_id}/zones?dnsName[eq]={subdomain}&active[eq]=true",
headers=self.headers,
Expand Down Expand Up @@ -227,12 +251,16 @@

Attributes:
credentials: A configuration object that holds STACKIT API credentials.
service_account: A configuration object that holds the service account file path.
"""

def __init__(self, *args, **kwargs):
"""Initialize the Authenticator by calling the parent's init method."""
super(Authenticator, self).__init__(*args, **kwargs)

self.credentials = None
self.service_account = None

@classmethod
def add_parser_arguments(cls, add: Callable, **kwargs):
"""
Expand All @@ -244,20 +272,25 @@
super(Authenticator, cls).add_parser_arguments(
add, default_propagation_seconds=900
)
add("service-account", help="Service account file path")
add("credentials", help="STACKIT credentials INI file.")
add("project-id", help="STACKIT project ID")

def _setup_credentials(self):
"""Set up and configure the STACKIT credentials."""
self.credentials = self._configure_credentials(
"credentials",
"STACKIT credentials for the STACKIT DNS API",
{
"project_id": "Specifies the project id of the STACKIT project.",
"auth_token": "Defines the authentication token for the STACKIT DNS API. Keep in mind that the "
"service account to this token need to have project edit permissions as we create txt "
"records in the zone",
},
)
"""Set up and configure the STACKIT credentials based on provided input."""
if self.conf("service_account") is not None:
self.service_account = self.conf("service_account")
else:
self.credentials = self._configure_credentials(
"credentials",
"STACKIT credentials for the STACKIT DNS API",
{
"project_id": "Specifies the project id of the STACKIT project.",
"auth_token": "Defines the authentication token for the STACKIT DNS API. Keep in mind that the "
"service account to this token need to have project edit permissions as we create txt "
"records in the zone",
},
)

def _perform(self, domain: str, validation_name: str, validation: str):
"""
Expand All @@ -281,16 +314,92 @@

def _get_stackit_client(self) -> _StackitClient:
"""
Instantiate and return a StackitClient object.
Instantiate and return a StackitClient object based on the authentication method.

:return: A _StackitClient instance to interact with the STACKIT DNS API.
:return: A StackitClient object.
"""
base_url = "https://dns.api.stackit.cloud"
if self.credentials.conf("base_url") is not None:
if self.credentials and self.credentials.conf("base_url") is not None:
base_url = self.credentials.conf("base_url")

if self.service_account is not None:
access_token = self._generate_jwt_token(self.conf("service_account"))
if access_token:
return _StackitClient(access_token, self.conf("project-id"), base_url)
return _StackitClient(
self.credentials.conf("auth_token"),
self.credentials.conf("project_id"),
base_url,
)

def _load_service_file(self, file_path: str) -> Optional[ServiceFileCredentials]:
"""
Load service file credentials from a specified file path.

:param file_path: The path to the service account file.
:return: Service file credentials if the file is found and valid, None otherwise.
"""
try:
with open(file_path, "r") as file:
return json.load(file)["credentials"]
except FileNotFoundError:
logging.error(f"File not found: {file_path}")
return None

def _generate_jwt(self, credentials: ServiceFileCredentials) -> str:
"""
Generate a JWT token using the provided service file credentials.

:param credentials: The service file credentials.
:return: A JWT token as a string.
"""
payload = {
"iss": credentials["iss"],
"sub": credentials["sub"],
"aud": credentials["aud"],
"exp": int(time.time()) + 900,
"iat": int(time.time()),
"jti": str(uuid.uuid4()),
}
headers = {"kid": credentials["kid"]}
return jwt.encode(
payload, credentials["privateKey"], algorithm="RS512", headers=headers
Fixed Show fixed Hide fixed
)

def _request_access_token(self, jwt_token: str) -> str:
"""
Request an access token using a JWT token.

:param jwt_token: The JWT token used to request the access token.
:return: An access token if the request is successful, None otherwise.
"""
data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": jwt_token,
}
try:
response = requests.post(
"https://service-account.api.stackit.cloud/token",
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
response.raise_for_status()
return response.json().get("access_token")
except requests.exceptions.RequestException as e:
raise errors.PluginError(f"Failed to request access token: {e}")

def _generate_jwt_token(self, file_path: str) -> Optional[str]:
"""
Generate a JWT token and request an access token using the service file at the given path.

:param file_path: The path to the service account file.
:return: An access token if the process is successful, None otherwise.
"""
credentials = self._load_service_file(file_path)
if credentials is None:
raise errors.PluginError("Failed to load service file credentials.")
jwt_token = self._generate_jwt(credentials)
bearer = self._request_access_token(jwt_token)
if bearer is None:
raise errors.PluginError("Could not obtain access token.")
return bearer
138 changes: 136 additions & 2 deletions certbot_dns_stackit/test_stackit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import unittest
from unittest.mock import patch, Mock
from unittest.mock import patch, Mock, mock_open
import json
import jwt
from requests.models import Response
from requests.exceptions import HTTPError

from certbot import errors
from certbot_dns_stackit.stackit import _StackitClient, RRSet, Record, Authenticator
Expand Down Expand Up @@ -214,13 +218,34 @@ def setUp(self):
mock_name = Mock()
self.authenticator = Authenticator(mock_config, mock_name)

@patch.object(Authenticator, "conf")
@patch.object(Authenticator, "_configure_credentials")
def test_setup_credentials(self, mock_configure_credentials):
def test_setup_credentials_with_service_account(
self, mock_configure_credentials, mock_conf
):
# Simulate `service_account` being set
mock_conf.return_value = "service_account_value"

self.authenticator._setup_credentials()

# Assert _configure_credentials was not called
mock_configure_credentials.assert_not_called()
# Assert service_account is set correctly
self.assertEqual(self.authenticator.service_account, "service_account_value")

@patch.object(Authenticator, "conf")
@patch.object(Authenticator, "_configure_credentials")
def test_setup_credentials_without_service_account(
self, mock_configure_credentials, mock_conf
):
# Simulate `service_account` not being set
mock_conf.return_value = None
mock_creds = Mock()
mock_configure_credentials.return_value = mock_creds

self.authenticator._setup_credentials()

# Assert _configure_credentials was called with the correct arguments
mock_configure_credentials.assert_called_once_with(
"credentials",
"STACKIT credentials for the STACKIT DNS API",
Expand All @@ -231,6 +256,7 @@ def test_setup_credentials(self, mock_configure_credentials):
"records in the zone",
},
)
# Assert credentials are set correctly
self.assertEqual(self.authenticator.credentials, mock_creds)

@patch.object(Authenticator, "_get_stackit_client")
Expand Down Expand Up @@ -261,6 +287,114 @@ def test_cleanup(self, mock_get_client):
"test_domain", "validation_name_test", "validation_test"
)

@patch(
"builtins.open",
new_callable=mock_open,
read_data='{"credentials": {"iss": "test_iss", "sub": "test_sub", "aud": "test_aud", "kid": "test_kid", "privateKey": "test_private_key"}}',
)
@patch("json.load", lambda x: json.loads(x.read()))
def test_load_service_file(self, mock_load_service_file):
expected_credentials = {
"iss": "test_iss",
"sub": "test_sub",
"aud": "test_aud",
"kid": "test_kid",
"privateKey": "test_private_key",
}

credentials = self.authenticator._load_service_file("dummy_path")
self.assertEqual(credentials, expected_credentials)

@patch("builtins.open", side_effect=FileNotFoundError())
@patch("logging.error")
def test_load_service_file_not_found(self, mock_log, mock_file):
result = self.authenticator._load_service_file("nonexistent_path")
self.assertIsNone(result)
mock_log.assert_called()

@patch("jwt.encode")
def test_generate_jwt(self, mock_jwt_encode):
credentials = {
"iss": "issuer",
"sub": "subject",
"aud": "audience",
"kid": "key_id",
"privateKey": "private_key",
}
self.authenticator._generate_jwt(credentials)
mock_jwt_encode.assert_called()

def test_generate_jwt_fail(self):
credentials = {
"iss": "issuer",
"sub": "subject",
"aud": "audience",
"kid": "key_id",
"privateKey": "not_a_valid_key",
}
with self.assertRaises(jwt.exceptions.InvalidKeyError):
token = self.authenticator._generate_jwt(credentials)
self.assertIsNone(token)

@patch("requests.post")
def test_request_access_token_success(self, mock_post):
mock_response = mock_post.return_value
mock_response.raise_for_status = (
lambda: None
) # Mock raise_for_status to do nothing
mock_response.json.return_value = {"access_token": "mocked_access_token"}

result = self.authenticator._request_access_token("jwt_token_example")

# Assertions
mock_post.assert_called_once_with(
"https://service-account.api.stackit.cloud/token",
data={
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": "jwt_token_example",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
self.assertEqual(result, "mocked_access_token")

@patch("requests.post")
def test_request_access_token_failure_raises_http_error(self, mock_post):
mock_response = Response()
mock_response.status_code = 403
mock_post.return_value = mock_response
mock_response.raise_for_status = lambda: (_ for _ in ()).throw(HTTPError())

with self.assertRaises(errors.PluginError):
self.authenticator._request_access_token("jwt_token_example")

mock_post.assert_called_once()

@patch(
"builtins.open",
new_callable=mock_open,
read_data='{"credentials": {"iss": "test_iss", "sub": "test_sub", "aud": "test_aud", "kid": "test_kid", "privateKey": "test_private_key"}}',
)
@patch.object(Authenticator, "_request_access_token")
@patch.object(Authenticator, "_generate_jwt")
@patch.object(Authenticator, "_load_service_file")
def test_generate_jwt_token_success(
self,
mock_load_service_file,
mock_generate_jwt,
mock_request_access_token,
mock_open,
):
mock_load_service_file.return_value = {"dummy": "credentials"}
mock_generate_jwt.return_value = "jwt_token_example"
mock_request_access_token.return_value = "access_token_example"

result = self.authenticator._generate_jwt_token("path/to/service/file")

self.assertEqual(result, "access_token_example")
mock_load_service_file.assert_called_once_with("path/to/service/file")
mock_generate_jwt.assert_called_once_with({"dummy": "credentials"})
mock_request_access_token.assert_called_once_with("jwt_token_example")


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ install_requires =
black
click==8.1.7
coverage
PyJWT==2.9.0

[options.entry_points]
certbot.plugins =
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"black",
"click==8.1.7",
"coverage",
"PyJWT==2.9.0"
]

# read the contents of your README file
Expand Down
Loading