diff --git a/charmcraft.yaml b/charmcraft.yaml index b83abfa7..92a1269c 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -16,6 +16,8 @@ containers: mounts: - storage: config location: /nms/config + - storage: certs + location: /support/TLS resources: nms-image: @@ -31,6 +33,9 @@ storage: config: type: filesystem minimum-size: 5M + certs: + type: filesystem + minimum-size: 1M requires: ingress: @@ -46,6 +51,8 @@ requires: interface: fiveg_n4 logging: interface: loki_push_api + certificates: + interface: tls-certificates provides: sdcore_config: diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py new file mode 100644 index 00000000..10ca8731 --- /dev/null +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -0,0 +1,1645 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Charm library for managing TLS certificates (V4). + +This library contains the Requires and Provides classes for handling the tls-certificates +interface. + +Pre-requisites: + - Juju >= 3.0 + - cryptography >= 43.0.0 + - pydantic + +Learn more on how-to use the TLS Certificates interface library by reading the documentation: +- https://charmhub.io/tls-certificates-interface/ + +""" # noqa: D214, D405, D411, D416 + +import copy +import ipaddress +import json +import logging +import uuid +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import FrozenSet, List, MutableMapping, Optional, Tuple, Union + +from cryptography import x509 +from cryptography.hazmat._oid import ExtensionOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent +from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import ( + Application, + ModelError, + Relation, + SecretNotFoundError, + Unit, +) +from pydantic import BaseModel, ConfigDict, ValidationError + +# The unique Charmhub library identifier, never change it +LIBID = "afd8c2bccf834997afce12c2706d2ede" + +# Increment this major API version when introducing breaking changes +LIBAPI = 4 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 1 + +PYDEPS = ["cryptography", "pydantic"] + +logger = logging.getLogger(__name__) + + +class TLSCertificatesError(Exception): + """Base class for custom errors raised by this library.""" + + +class DataValidationError(TLSCertificatesError): + """Raised when data validation fails.""" + + +class _DatabagModel(BaseModel): + """Base databag model.""" + + model_config = ConfigDict( + # tolerate additional keys in databag + extra="ignore", + # Allow instantiating this class by field name (instead of forcing alias). + populate_by_name=True, + # Custom config key: whether to nest the whole datastructure (as json) + # under a field or spread it out at the toplevel. + _NEST_UNDER=None, + ) # type: ignore + """Pydantic config.""" + + @classmethod + def load(cls, databag: MutableMapping): + """Load this model from a Juju databag.""" + nest_under = cls.model_config.get("_NEST_UNDER") + if nest_under: + return cls.model_validate(json.loads(databag[nest_under])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {(f.alias or n) for n, f in cls.model_fields.items()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.model_validate_json(json.dumps(data)) + except ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Write the contents of this model to Juju databag. + + Args: + databag: The databag to write to. + clear: Whether to clear the databag before writing. + + Returns: + MutableMapping: The databag. + """ + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + nest_under = self.model_config.get("_NEST_UNDER") + if nest_under: + databag[nest_under] = self.model_dump_json( + by_alias=True, + # skip keys whose values are default + exclude_defaults=True, + ) + return databag + + dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True) + databag.update({k: json.dumps(v) for k, v in dct.items()}) + return databag + + +class _Certificate(BaseModel): + """Certificate model.""" + + ca: str + certificate_signing_request: str + certificate: str + chain: Optional[List[str]] = None + revoked: Optional[bool] = None + + def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": + """Convert to a ProviderCertificate.""" + return ProviderCertificate( + relation_id=relation_id, + certificate=Certificate.from_string(self.certificate), + certificate_signing_request=CertificateSigningRequest.from_string( + self.certificate_signing_request + ), + ca=Certificate.from_string(self.ca), + chain=[Certificate.from_string(certificate) for certificate in self.chain] + if self.chain + else [], + revoked=self.revoked, + ) + + +class _CertificateSigningRequest(BaseModel): + """Certificate signing request model.""" + + certificate_signing_request: str + ca: Optional[bool] + + +class _ProviderApplicationData(_DatabagModel): + """Provider application data model.""" + + certificates: List[_Certificate] + + +class _RequirerData(_DatabagModel): + """Requirer data model. + + The same model is used for the unit and application data. + """ + + certificate_signing_requests: List[_CertificateSigningRequest] + + +class Mode(Enum): + """Enum representing the mode of the certificate request. + + UNIT (default): Request a certificate for the unit. + Each unit will have its own private key and certificate. + APP: Request a certificate for the application. + The private key and certificate will be shared by all units. + """ + + UNIT = 1 + APP = 2 + + +@dataclass(frozen=True) +class PrivateKey: + """This class represents a private key.""" + + raw: str + + def __str__(self): + """Return the private key as a string.""" + return self.raw + + @classmethod + def from_string(cls, private_key: str) -> "PrivateKey": + """Create a PrivateKey object from a private key.""" + return cls(raw=private_key.strip()) + + +@dataclass(frozen=True) +class Certificate: + """This class represents a certificate.""" + + raw: str + common_name: str + expiry_time: datetime + validity_start_time: datetime + is_ca: bool = False + sans_dns: Optional[FrozenSet[str]] = frozenset() + sans_ip: Optional[FrozenSet[str]] = frozenset() + sans_oid: Optional[FrozenSet[str]] = frozenset() + email_address: Optional[str] = None + organization: Optional[str] = None + organizational_unit: Optional[str] = None + country_name: Optional[str] = None + state_or_province_name: Optional[str] = None + locality_name: Optional[str] = None + + def __str__(self) -> str: + """Return the certificate as a string.""" + return self.raw + + @classmethod + def from_string(cls, certificate: str) -> "Certificate": + """Create a Certificate object from a certificate.""" + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate: %s", e) + raise TLSCertificatesError("Could not load certificate") + + common_name = certificate_object.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + country_name = certificate_object.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + state_or_province_name = certificate_object.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + locality_name = certificate_object.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + organization_name = certificate_object.subject.get_attributes_for_oid( + NameOID.ORGANIZATION_NAME + ) + organizational_unit = certificate_object.subject.get_attributes_for_oid( + NameOID.ORGANIZATIONAL_UNIT_NAME + ) + email_address = certificate_object.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + sans_dns: List[str] = [] + sans_ip: List[str] = [] + sans_oid: List[str] = [] + try: + sans = certificate_object.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ).value + for san in sans: + if isinstance(san, x509.DNSName): + sans_dns.append(san.value) + if isinstance(san, x509.IPAddress): + sans_ip.append(str(san.value)) + if isinstance(san, x509.RegisteredID): + sans_oid.append(str(san.value)) + except x509.ExtensionNotFound: + logger.debug("No SANs found in certificate") + sans_dns = [] + sans_ip = [] + sans_oid = [] + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + is_ca = False + try: + is_ca = certificate_object.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value.ca # type: ignore[reportAttributeAccessIssue] + except x509.ExtensionNotFound: + pass + + return cls( + raw=certificate.strip(), + common_name=str(common_name[0].value), + is_ca=is_ca, + country_name=str(country_name[0].value) if country_name else None, + state_or_province_name=str(state_or_province_name[0].value) + if state_or_province_name + else None, + locality_name=str(locality_name[0].value) if locality_name else None, + organization=str(organization_name[0].value) if organization_name else None, + organizational_unit=str(organizational_unit[0].value) if organizational_unit else None, + email_address=str(email_address[0].value) if email_address else None, + sans_dns=frozenset(sans_dns), + sans_ip=frozenset(sans_ip), + sans_oid=frozenset(sans_oid), + expiry_time=expiry_time, + validity_start_time=validity_start_time, + ) + + +@dataclass(frozen=True) +class CertificateSigningRequest: + """This class represents a certificate signing request.""" + + raw: str + common_name: str + sans_dns: Optional[FrozenSet[str]] = None + sans_ip: Optional[FrozenSet[str]] = None + sans_oid: Optional[FrozenSet[str]] = None + email_address: Optional[str] = None + organization: Optional[str] = None + organizational_unit: Optional[str] = None + country_name: Optional[str] = None + state_or_province_name: Optional[str] = None + locality_name: Optional[str] = None + + def __eq__(self, other: object) -> bool: + """Check if two CertificateSigningRequest objects are equal.""" + if not isinstance(other, CertificateSigningRequest): + return NotImplemented + return self.raw.strip() == other.raw.strip() + + def __str__(self) -> str: + """Return the CSR as a string.""" + return self.raw + + @classmethod + def from_string(cls, csr: str) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + try: + csr_object = x509.load_pem_x509_csr(csr.encode()) + except ValueError as e: + logger.error("Could not load CSR: %s", e) + raise TLSCertificatesError("Could not load CSR") + common_name = csr_object.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + country_name = csr_object.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + state_or_province_name = csr_object.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + locality_name = csr_object.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + organization_name = csr_object.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + email_address = csr_object.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + try: + sans = csr_object.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + sans_dns = frozenset(sans.get_values_for_type(x509.DNSName)) + sans_ip = frozenset([str(san) for san in sans.get_values_for_type(x509.IPAddress)]) + sans_oid = frozenset([str(san) for san in sans.get_values_for_type(x509.RegisteredID)]) + except x509.ExtensionNotFound: + sans = frozenset() + sans_dns = frozenset() + sans_ip = frozenset() + sans_oid = frozenset() + return cls( + raw=csr.strip(), + common_name=str(common_name[0].value), + country_name=str(country_name[0].value) if country_name else None, + state_or_province_name=str(state_or_province_name[0].value) + if state_or_province_name + else None, + locality_name=str(locality_name[0].value) if locality_name else None, + organization=str(organization_name[0].value) if organization_name else None, + email_address=str(email_address[0].value) if email_address else None, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + ) + + def matches_private_key(self, key: PrivateKey) -> bool: + """Check if a CSR matches a private key. + + This function only works with RSA keys. + + Args: + key (PrivateKey): Private key + Returns: + bool: True/False depending on whether the CSR matches the private key. + """ + try: + csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8")) + key_object = serialization.load_pem_private_key( + data=key.raw.encode("utf-8"), password=None + ) + key_object_public_key = key_object.public_key() + csr_object_public_key = csr_object.public_key() + if not isinstance(key_object_public_key, rsa.RSAPublicKey): + logger.warning("Key is not an RSA key") + return False + if not isinstance(csr_object_public_key, rsa.RSAPublicKey): + logger.warning("CSR is not an RSA key") + return False + if ( + csr_object_public_key.public_numbers().n + != key_object_public_key.public_numbers().n + ): + logger.warning("Public key numbers between CSR and key do not match") + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + def matches_certificate(self, certificate: Certificate) -> bool: + """Check if a CSR matches a certificate. + + Args: + certificate (Certificate): Certificate + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(certificate.raw.encode("utf-8")) + return csr_object.public_key() == cert_object.public_key() + + def get_sha256_hex(self) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(self.raw.encode()) + return digest.finalize().hex() + + +@dataclass(frozen=True) +class CertificateRequestAttributes: + """A representation of the certificate request attributes. + + This class should be used inside the requirer charm to specify the requested + attributes for the certificate. + """ + + common_name: str + sans_dns: Optional[FrozenSet[str]] = frozenset() + sans_ip: Optional[FrozenSet[str]] = frozenset() + sans_oid: Optional[FrozenSet[str]] = frozenset() + email_address: Optional[str] = None + organization: Optional[str] = None + organizational_unit: Optional[str] = None + country_name: Optional[str] = None + state_or_province_name: Optional[str] = None + locality_name: Optional[str] = None + is_ca: bool = False + + def is_valid(self) -> bool: + """Check whether the certificate request is valid.""" + if not self.common_name: + return False + return True + + def generate_csr( + self, + private_key: PrivateKey, + ) -> CertificateSigningRequest: + """Generate a CSR using private key and subject. + + Args: + private_key (PrivateKey): Private key + + Returns: + CertificateSigningRequest: CSR + """ + return generate_csr( + private_key=private_key, + common_name=self.common_name, + sans_dns=self.sans_dns, + sans_ip=self.sans_ip, + sans_oid=self.sans_oid, + email_address=self.email_address, + organization=self.organization, + organizational_unit=self.organizational_unit, + country_name=self.country_name, + state_or_province_name=self.state_or_province_name, + locality_name=self.locality_name, + ) + + @classmethod + def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool): + """Create a CertificateRequestAttributes object from a CSR.""" + return cls( + common_name=csr.common_name, + sans_dns=csr.sans_dns, + sans_ip=csr.sans_ip, + sans_oid=csr.sans_oid, + email_address=csr.email_address, + organization=csr.organization, + organizational_unit=csr.organizational_unit, + country_name=csr.country_name, + state_or_province_name=csr.state_or_province_name, + locality_name=csr.locality_name, + is_ca=is_ca, + ) + + +@dataclass(frozen=True) +class ProviderCertificate: + """This class represents a certificate provided by the TLS provider.""" + + relation_id: int + certificate: Certificate + certificate_signing_request: CertificateSigningRequest + ca: Certificate + chain: List[Certificate] + revoked: Optional[bool] = None + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "csr": str(self.certificate_signing_request), + "certificate": str(self.certificate), + "ca": str(self.ca), + "chain": [str(cert) for cert in self.chain], + "revoked": self.revoked, + } + ) + + +@dataclass(frozen=True) +class RequirerCertificateRequest: + """This class represents a certificate signing request requested by a specific TLS requirer.""" + + relation_id: int + certificate_signing_request: CertificateSigningRequest + is_ca: bool + + +class CertificateAvailableEvent(EventBase): + """Charm Event triggered when a TLS certificate is available.""" + + def __init__( + self, + handle: Handle, + certificate: Certificate, + certificate_signing_request: CertificateSigningRequest, + ca: Certificate, + chain: List[Certificate], + ): + super().__init__(handle) + self.certificate = certificate + self.certificate_signing_request = certificate_signing_request + self.ca = ca + self.chain = chain + + def snapshot(self) -> dict: + """Return snapshot.""" + return { + "certificate": str(self.certificate), + "certificate_signing_request": str(self.certificate_signing_request), + "ca": str(self.ca), + "chain": json.dumps([str(certificate) for certificate in self.chain]), + } + + def restore(self, snapshot: dict): + """Restore snapshot.""" + self.certificate = Certificate.from_string(snapshot["certificate"]) + self.certificate_signing_request = CertificateSigningRequest.from_string( + snapshot["certificate_signing_request"] + ) + self.ca = Certificate.from_string(snapshot["ca"]) + chain_strs = json.loads(snapshot["chain"]) + self.chain = [Certificate.from_string(chain_str) for chain_str in chain_strs] + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join([str(cert) for cert in self.chain]) + + +def generate_private_key( + key_size: int = 2048, + public_exponent: int = 65537, +) -> PrivateKey: + """Generate a private key with the RSA algorithm. + + Args: + key_size (int): Key size in bytes + public_exponent: Public exponent. + + Returns: + PrivateKey: Private Key + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + return PrivateKey.from_string(key_bytes.decode()) + + +def generate_csr( # noqa: C901 + private_key: PrivateKey, + common_name: str, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + email_address: Optional[str] = None, + country_name: Optional[str] = None, + locality_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + add_unique_id_to_subject_name: bool = True, +) -> CertificateSigningRequest: + """Generate a CSR using private key and subject. + + Args: + private_key (PrivateKey): Private key + common_name (str): Common name + sans_dns (FrozenSet[str]): DNS Subject Alternative Names + sans_ip (FrozenSet[str]): IP Subject Alternative Names + sans_oid (FrozenSet[str]): OID Subject Alternative Names + organization (Optional[str]): Organization name + organizational_unit (Optional[str]): Organizational unit name + email_address (Optional[str]): Email address + country_name (Optional[str]): Country name + state_or_province_name (Optional[str]): State or province name + locality_name (Optional[str]): Locality name + add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's + subject name. Always leave to "True" when the CSR is used to request certificates + using the tls-certificates relation. + + Returns: + CertificateSigningRequest: CSR + """ + signing_key = serialization.load_pem_private_key(str(private_key).encode(), password=None) + subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] + if add_unique_id_to_subject_name: + unique_identifier = uuid.uuid4() + subject_name.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + if organization: + subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization)) + if organizational_unit: + subject_name.append( + x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit) + ) + if email_address: + subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) + if country_name: + subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) + csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) + + _sans: List[x509.GeneralName] = [] + if sans_oid: + _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) + if sans_ip: + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) + if sans_dns: + _sans.extend([x509.DNSName(san) for san in sans_dns]) + if _sans: + csr = csr.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) + signed_certificate = csr.sign(signing_key, hashes.SHA256()) # type: ignore[arg-type] + csr_str = signed_certificate.public_bytes(serialization.Encoding.PEM).decode() + return CertificateSigningRequest.from_string(csr_str) + + +def generate_ca( + private_key: PrivateKey, + validity: timedelta, + common_name: str, + sans_dns: Optional[FrozenSet[str]] = frozenset(), + sans_ip: Optional[FrozenSet[str]] = frozenset(), + sans_oid: Optional[FrozenSet[str]] = frozenset(), + organization: Optional[str] = None, + organizational_unit: Optional[str] = None, + email_address: Optional[str] = None, + country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, +) -> Certificate: + """Generate a self signed CA Certificate. + + Args: + private_key (PrivateKey): Private key + validity (timedelta): Certificate validity time + common_name (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + sans_dns (FrozenSet[str]): DNS Subject Alternative Names + sans_ip (FrozenSet[str]): IP Subject Alternative Names + sans_oid (FrozenSet[str]): OID Subject Alternative Names + organization (Optional[str]): Organization name + organizational_unit (Optional[str]): Organizational unit name + email_address (Optional[str]): Email address + country_name (str): Certificate Issuing country + state_or_province_name (str): Certificate Issuing state or province + locality_name (str): Certificate Issuing locality + + Returns: + Certificate: CA Certificate. + """ + private_key_object = serialization.load_pem_private_key( + str(private_key).encode(), password=None + ) + assert isinstance(private_key_object, rsa.RSAPrivateKey) + subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] + if organization: + subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization)) + if organizational_unit: + subject_name.append( + x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit) + ) + if email_address: + subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) + if country_name: + subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) + + _sans: List[x509.GeneralName] = [] + if sans_oid: + _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) + if sans_ip: + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) + if sans_dns: + _sans.extend([x509.DNSName(san) for san in sans_dns]) + + subject_identifier_object = x509.SubjectKeyIdentifier.from_public_key( + private_key_object.public_key() + ) + subject_identifier = key_identifier = subject_identifier_object.public_bytes() + key_usage = x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) + cert = ( + x509.CertificateBuilder() + .subject_name(x509.Name(subject_name)) + .issuer_name(x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)])) + .public_key(private_key_object.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + validity) + .add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) + .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) + .add_extension( + x509.AuthorityKeyIdentifier( + key_identifier=key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ) + .add_extension(key_usage, critical=True) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(private_key_object, hashes.SHA256()) # type: ignore[arg-type] + ) + ca_cert_str = cert.public_bytes(serialization.Encoding.PEM).decode().strip() + return Certificate.from_string(ca_cert_str) + + +def generate_certificate( + csr: CertificateSigningRequest, + ca: Certificate, + ca_private_key: PrivateKey, + validity: timedelta, + is_ca: bool = False, +) -> Certificate: + """Generate a TLS certificate based on a CSR. + + Args: + csr (CertificateSigningRequest): CSR + ca (Certificate): CA Certificate + ca_private_key (PrivateKey): CA private key + validity (timedelta): Certificate validity time + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + Certificate: Certificate + """ + csr_object = x509.load_pem_x509_csr(str(csr).encode()) + subject = csr_object.subject + ca_pem = x509.load_pem_x509_certificate(str(ca).encode()) + issuer = ca_pem.issuer + private_key = serialization.load_pem_private_key(str(ca_private_key).encode(), password=None) + + certificate_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(csr_object.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + validity) + ) + extensions = _get_certificate_request_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + is_ca=is_ca, + ) + for extension in extensions: + try: + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, + ) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) + + cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] + cert_bytes = cert.public_bytes(serialization.Encoding.PEM) + return Certificate.from_string(cert_bytes.decode().strip()) + + +def _get_certificate_request_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + sans: List[x509.GeneralName] = [] + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + +class CertificatesRequirerCharmEvents(CharmEvents): + """List of events that the TLS Certificates requirer charm can leverage.""" + + certificate_available = EventSource(CertificateAvailableEvent) + + +class TLSCertificatesRequiresV4(Object): + """A class to manage the TLS certificates interface for a unit or app.""" + + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] + + def __init__( + self, + charm: CharmBase, + relationship_name: str, + certificate_requests: List[CertificateRequestAttributes], + mode: Mode = Mode.UNIT, + refresh_events: List[BoundEvent] = [], + ): + """Create a new instance of the TLSCertificatesRequiresV4 class. + + Args: + charm (CharmBase): The charm instance to relate to. + relationship_name (str): The name of the relation that provides the certificates. + certificate_requests (List[CertificateRequestAttributes]): + A list with the attributes of the certificate requests. + mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. + refresh_events (List[BoundEvent]): A list of events to trigger a refresh of + the certificates. + """ + super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") + if not self._mode_is_valid(mode): + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP") + for certificate_request in certificate_requests: + if not certificate_request.is_valid(): + raise TLSCertificatesError("Invalid certificate request") + self.charm = charm + self.relationship_name = relationship_name + self.certificate_requests = certificate_requests + self.mode = mode + self.framework.observe(charm.on[relationship_name].relation_created, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + for event in refresh_events: + self.framework.observe(event, self._configure) + + def _configure(self, _: EventBase): + """Handle TLS Certificates Relation Data. + + This method is called during any TLS relation event. + It will generate a private key if it doesn't exist yet. + It will send certificate requests if they haven't been sent yet. + It will find available certificates and emit events. + """ + if not self._tls_relation_created(): + logger.debug("TLS relation not created yet.") + return + self._generate_private_key() + self._send_certificate_requests() + self._find_available_certificates() + self._cleanup_certificate_requests() + + def _mode_is_valid(self, mode) -> bool: + return mode in [Mode.UNIT, Mode.APP] + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle Secret Expired Event. + + Renews certificate requests and removes the expired secret. + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"): + return + try: + csr_str = event.secret.get_content(refresh=True)["csr"] + except ModelError: + logger.error("Failed to get CSR from secret - Skipping") + return + csr = CertificateSigningRequest.from_string(csr_str) + self._renew_certificate_request(csr) + event.secret.remove_all_revisions() + + def renew_certificate(self, certificate: ProviderCertificate) -> None: + """Request the renewal of the provided certificate.""" + certificate_signing_request = certificate.certificate_signing_request + secret_label = self._get_csr_secret_label(certificate_signing_request) + try: + secret = self.model.get_secret(label=secret_label) + except SecretNotFoundError: + logger.warning("No matching secret found - Skipping renewal") + return + current_csr = secret.get_content(refresh=True).get("csr", "") + if current_csr != str(certificate_signing_request): + logger.warning("No matching CSR found - Skipping renewal") + return + self._renew_certificate_request(certificate_signing_request) + secret.remove_all_revisions() + + def _renew_certificate_request(self, csr: CertificateSigningRequest): + """Remove existing CSR from relation data and create a new one.""" + self._remove_requirer_csr_from_relation_data(csr) + self._send_certificate_requests() + logger.info("Renewed certificate request") + + def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + if not self.get_csrs_from_requirer_relation_data(): + logger.info("No CSRs in relation data - Doing nothing") + return + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data - Skipping removal of CSR") + return + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + for requirer_csr in new_relation_data: + if requirer_csr.certificate_signing_request.strip() == str(csr).strip(): + new_relation_data.remove(requirer_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Removed CSR from relation data") + except ModelError: + logger.warning("Failed to update relation data") + + def _get_app_or_unit(self) -> Union[Application, Unit]: + """Return the unit or app object based on the mode.""" + if self.mode == Mode.UNIT: + return self.model.unit + elif self.mode == Mode.APP: + return self.model.app + raise TLSCertificatesError("Invalid mode") + + @property + def private_key(self) -> PrivateKey | None: + """Return the private key.""" + if not self._private_key_generated(): + return None + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + private_key = secret.get_content(refresh=True)["private-key"] + return PrivateKey.from_string(private_key) + + def _generate_private_key(self) -> None: + if self._private_key_generated(): + return + private_key = generate_private_key() + self.charm.unit.add_secret( + content={"private-key": str(private_key)}, + label=self._get_private_key_secret_label(), + ) + logger.info("Private key generated") + + def regenerate_private_key(self) -> None: + """Regenerate the private key. + + Generate a new private key, remove old certificate requests and send new ones. + """ + if not self._private_key_generated(): + logger.warning("No private key to regenerate") + return + self._regenerate_private_key() + self._cleanup_certificate_requests() + self._send_certificate_requests() + + def _regenerate_private_key(self) -> None: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.set_content({"private-key": str(generate_private_key())}) + + def _private_key_generated(self) -> bool: + try: + self.charm.model.get_secret(label=self._get_private_key_secret_label()) + except (SecretNotFoundError, KeyError): + return False + return True + + def _csr_matches_certificate_request( + self, certificate_signing_request: CertificateSigningRequest, is_ca: bool + ) -> bool: + for certificate_request in self.certificate_requests: + if certificate_request == CertificateRequestAttributes.from_csr( + certificate_signing_request, + is_ca, + ): + return True + return False + + def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: + if not self.private_key: + return False + csr = self._certificate_requested_for_attributes(certificate_request) + if not csr: + return False + if not csr.certificate_signing_request.matches_private_key(key=self.private_key): + return False + return True + + def _certificate_requested_for_attributes( + self, + certificate_request: CertificateRequestAttributes, + ) -> Optional[RequirerCertificateRequest]: + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): + return requirer_csr + return None + + def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: + """Return list of requirer's CSRs from relation data.""" + if self.mode == Mode.APP and not self.model.unit.is_leader(): + logger.debug("Not a leader unit - Skipping") + return [] + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + requirer_csrs = [] + for csr in requirer_relation_data.certificate_signing_requests: + requirer_csrs.append( + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + ) + return requirer_csrs + + def get_provider_certificates(self) -> List[ProviderCertificate]: + """Return list of certificates from the provider's relation data.""" + return self._load_provider_certificates() + + def _load_provider_certificates(self) -> List[ProviderCertificate]: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + if not relation.app: + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + return [ + certificate.to_provider_certificate(relation_id=relation.id) + for certificate in provider_relation_data.certificates + ] + + def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None: + """Add CSR to relation data.""" + if self.mode == Mode.APP and not self.model.unit.is_leader(): + logger.debug("Not a leader unit - Skipping") + return + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + new_csr = _CertificateSigningRequest( + certificate_signing_request=str(csr).strip(), ca=is_ca + ) + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + requirer_relation_data = _RequirerData( + certificate_signing_requests=[], + ) + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + new_relation_data.append(new_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Certificate signing request added to relation data.") + except ModelError: + logger.warning("Failed to update relation data") + + def _send_certificate_requests(self): + if not self.private_key: + logger.debug("Private key not generated yet.") + return + for certificate_request in self.certificate_requests: + if not self._certificate_requested(certificate_request): + csr = certificate_request.generate_csr( + private_key=self.private_key, + ) + if not csr: + logger.warning("Failed to generate CSR") + continue + self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) + + def get_assigned_certificate( + self, certificate_request: CertificateRequestAttributes + ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: + """Get the certificate that was assigned to the given certificate request.""" + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == CertificateRequestAttributes.from_csr( + requirer_csr.certificate_signing_request, + requirer_csr.is_ca, + ): + return self._find_certificate_in_relation_data(requirer_csr), self.private_key + return None, None + + def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateKey | None]: + """Get a list of certificates that were assigned to this or app.""" + assigned_certificates = [] + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if cert := self._find_certificate_in_relation_data(requirer_csr): + assigned_certificates.append(cert) + return assigned_certificates, self.private_key + + def _find_certificate_in_relation_data( + self, csr: RequirerCertificateRequest + ) -> Optional[ProviderCertificate]: + """Return the certificate that match the given CSR.""" + for provider_certificate in self.get_provider_certificates(): + if ( + provider_certificate.certificate_signing_request == csr.certificate_signing_request + and provider_certificate.certificate.is_ca == csr.is_ca + ): + return provider_certificate + return None + + def _find_available_certificates(self): + """Find available certificates and emit events. + + This method will find certificates that are available for the requirer's CSRs. + If a certificate is found, it will be set as a secret and an event will be emitted. + If a certificate is revoked, the secret will be removed and an event will be emitted. + """ + requirer_csrs = self.get_csrs_from_requirer_relation_data() + csrs = [csr.certificate_signing_request for csr in requirer_csrs] + provider_certificates = self.get_provider_certificates() + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request in csrs: + secret_label = self._get_csr_secret_label( + provider_certificate.certificate_signing_request + ) + if provider_certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + secret_label, + ) + secret = self.model.get_secret(label=secret_label) + secret.remove_all_revisions() + else: + if not self._csr_matches_certificate_request( + certificate_signing_request=provider_certificate.certificate_signing_request, + is_ca=provider_certificate.certificate.is_ca, + ): + logger.debug("Certificate requested for different attributes - Skipping") + continue + try: + secret = self.model.get_secret(label=secret_label) + logger.debug("Setting secret with label %s", secret_label) + # Juju < 3.6 will create a new revision even if the content is the same + if secret.get_content(refresh=True).get("certificate", "") == str( + provider_certificate.certificate + ): + logger.debug( + "Secret %s with correct certificate already exists", secret_label + ) + return + secret.set_content( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + } + ) + secret.set_info( + expire=provider_certificate.certificate.expiry_time, + ) + except SecretNotFoundError: + logger.debug("Creating new secret with label %s", secret_label) + secret = self.charm.unit.add_secret( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + }, + label=secret_label, + expire=provider_certificate.certificate.expiry_time, + ) + self.on.certificate_available.emit( + certificate_signing_request=provider_certificate.certificate_signing_request, + certificate=provider_certificate.certificate, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + + def _cleanup_certificate_requests(self): + """Clean up certificate requests. + + Remove any certificate requests that falls into one of the following categories: + - The CSR attributes do not match any of the certificate requests defined in + the charm's certificate_requests attribute. + - The CSR public key does not match the private key. + """ + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if not self._csr_matches_certificate_request( + certificate_signing_request=requirer_csr.certificate_signing_request, + is_ca=requirer_csr.is_ca, + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) + logger.info( + "Removed CSR from relation data because it did not match any certificate request" # noqa: E501 + ) + elif ( + self.private_key + and not requirer_csr.certificate_signing_request.matches_private_key( + self.private_key + ) + ): + self._remove_requirer_csr_from_relation_data( + requirer_csr.certificate_signing_request + ) + logger.info( + "Removed CSR from relation data because it did not match the private key" # noqa: E501 + ) + + def _tls_relation_created(self) -> bool: + relation = self.model.get_relation(self.relationship_name) + if not relation: + return False + return True + + def _get_private_key_secret_label(self) -> str: + if self.mode == Mode.UNIT: + return f"{LIBID}-private-key-{self._get_unit_number()}" + elif self.mode == Mode.APP: + return f"{LIBID}-private-key" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str: + csr_in_sha256_hex = csr.get_sha256_hex() + if self.mode == Mode.UNIT: + return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}" + elif self.mode == Mode.APP: + return f"{LIBID}-certificate-{csr_in_sha256_hex}" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_unit_number(self) -> str: + return self.model.unit.name.split("/")[1] + + +class TLSCertificatesProvidesV4(Object): + """TLS certificates provider class to be instantiated by TLS certificates providers.""" + + def __init__(self, charm: CharmBase, relationship_name: str): + super().__init__(charm, relationship_name) + self.framework.observe(charm.on[relationship_name].relation_joined, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.update_status, self._configure) + self.charm = charm + self.relationship_name = relationship_name + + def _configure(self, _: EventBase) -> None: + """Handle update status and tls relation changed events. + + This is a common hook triggered on a regular basis. + + Revoke certificates for which no csr exists + """ + if not self.model.unit.is_leader(): + return + self._remove_certificates_for_which_no_csr_exists() + + def _remove_certificates_for_which_no_csr_exists(self) -> None: + provider_certificates = self.get_provider_certificates() + requirer_csrs = [ + request.certificate_signing_request for request in self.get_certificate_requests() + ] + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request not in requirer_csrs: + tls_relation = self._get_tls_relations( + relation_id=provider_certificate.relation_id + ) + self._remove_provider_certificate( + certificate=provider_certificate.certificate, + relation=tls_relation[0], + ) + + def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]: + return ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + + def get_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: + """Load certificate requests from the relation data.""" + relations = self._get_tls_relations(relation_id) + requirer_csrs: List[RequirerCertificateRequest] = [] + for relation in relations: + for unit in relation.units: + requirer_csrs.extend(self._load_requirer_databag(relation, unit)) + requirer_csrs.extend(self._load_requirer_databag(relation, relation.app)) + return requirer_csrs + + def _load_requirer_databag( + self, relation: Relation, unit_or_app: Union[Application, Unit] + ) -> List[RequirerCertificateRequest]: + try: + requirer_relation_data = _RequirerData.load(relation.data[unit_or_app]) + except DataValidationError: + logger.debug("Invalid requirer relation data for %s", unit_or_app.name) + return [] + return [ + RequirerCertificateRequest( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + is_ca=csr.ca if csr.ca else False, + ) + for csr in requirer_relation_data.certificate_signing_requests + ] + + def _add_provider_certificate( + self, + relation: Relation, + provider_certificate: ProviderCertificate, + ) -> None: + new_certificate = _Certificate( + certificate=str(provider_certificate.certificate), + certificate_signing_request=str(provider_certificate.certificate_signing_request), + ca=str(provider_certificate.ca), + chain=[str(certificate) for certificate in provider_certificate.chain], + ) + provider_certificates = self._load_provider_certificates(relation) + if new_certificate in provider_certificates: + logger.info("Certificate already in relation data - Doing nothing") + return + provider_certificates.append(new_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]: + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app]) + except DataValidationError: + logger.debug("Invalid provider relation data") + return [] + return copy.deepcopy(provider_relation_data.certificates) + + def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]): + try: + _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app]) + logger.info("Certificate relation data updated") + except ModelError: + logger.warning("Failed to update relation data") + + def _remove_provider_certificate( + self, + relation: Relation, + certificate: Optional[Certificate] = None, + certificate_signing_request: Optional[CertificateSigningRequest] = None, + ) -> None: + """Remove certificate based on certificate or certificate signing request.""" + provider_certificates = self._load_provider_certificates(relation) + for provider_certificate in provider_certificates: + if certificate and provider_certificate.certificate == str(certificate): + provider_certificates.remove(provider_certificate) + if ( + certificate_signing_request + and provider_certificate.certificate_signing_request + == str(certificate_signing_request) + ): + provider_certificates.remove(provider_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def revoke_all_certificates(self) -> None: + """Revoke all certificates of this provider. + + This method is meant to be used when the Root CA has changed. + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + relations = self._get_tls_relations() + for relation in relations: + provider_certificates = self._load_provider_certificates(relation) + for certificate in provider_certificates: + certificate.revoked = True + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def set_relation_certificate( + self, + provider_certificate: ProviderCertificate, + ) -> None: + """Add certificates to relation data. + + Args: + provider_certificate (ProviderCertificate): ProviderCertificate object + + Returns: + None + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + certificates_relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=provider_certificate.relation_id + ) + if not certificates_relation: + raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist") + self._remove_provider_certificate( + relation=certificates_relation, + certificate_signing_request=provider_certificate.certificate_signing_request, + ) + self._add_provider_certificate( + relation=certificates_relation, + provider_certificate=provider_certificate, + ) + + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not read relation data") + return [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates.""" + certificates: List[ProviderCertificate] = [] + relations = self._get_tls_relations(relation_id) + for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue + for certificate in self._load_provider_certificates(relation): + certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) + return certificates + + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. + """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs] + for certificate in provider_certificates: + if certificate.certificate_signing_request not in list_of_csrs: + unsolicited_certificates.append(certificate) + return unsolicited_certificates + + def get_outstanding_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: + """Return CSR's for which no certificate has been issued. + + Args: + relation_id (int): Relation id + + Returns: + list: List of RequirerCertificateRequest objects. + """ + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + outstanding_csrs: List[RequirerCertificateRequest] = [] + for relation_csr in requirer_csrs: + if not self._certificate_issued_for_csr( + csr=relation_csr.certificate_signing_request, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def _certificate_issued_for_csr( + self, csr: CertificateSigningRequest, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR.""" + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.certificate_signing_request == csr: + return csr.matches_certificate(issued_certificate.certificate) + return False diff --git a/src/charm.py b/src/charm.py index b0364391..e68fece0 100755 --- a/src/charm.py +++ b/src/charm.py @@ -5,6 +5,7 @@ """Charmed operator for the Aether SD-Core Graphical User Interface for K8s.""" import logging +import socket from ipaddress import IPv4Address from subprocess import CalledProcessError, check_output from typing import List, Optional @@ -26,19 +27,21 @@ CollectStatusEvent, ModelError, WaitingStatus, + main, ) from ops.charm import CharmBase from ops.framework import EventBase -from ops.main import main from ops.pebble import Layer from nms import NMS, GnodeB, Upf +from tls import CA_CERTIFICATE_NAME, Tls logger = logging.getLogger(__name__) BASE_CONFIG_PATH = "/nms/config" CONFIG_FILE_NAME = "nmscfg.conf" NMS_CONFIG_PATH = f"{BASE_CONFIG_PATH}/{CONFIG_FILE_NAME}" +CERTS_MOUNT_PATH = "/support/TLS" WORKLOAD_VERSION_FILE_NAME = "/etc/workload-version" AUTH_DATABASE_RELATION_NAME = "auth_database" COMMON_DATABASE_RELATION_NAME = "common_database" @@ -50,7 +53,13 @@ COMMON_DATABASE_NAME = "free5gc" GRPC_PORT = 9876 NMS_URL_PORT = 5000 - +TLS_RELATION_NAME = "certificates" +MANDATORY_RELATIONS = [ + COMMON_DATABASE_RELATION_NAME, + AUTH_DATABASE_RELATION_NAME, + TLS_RELATION_NAME, +] +CA_CERTIFICATE_CHARM_PATH = f"/var/lib/juju/storage/certs/0/{CA_CERTIFICATE_NAME}" def _get_pod_ip() -> Optional[str]: """Return the pod IP using juju client.""" @@ -61,23 +70,6 @@ def _get_pod_ip() -> Optional[str]: return None -def render_config_file( - common_database_name: str, - common_database_url: str, - auth_database_name: str, - auth_database_url: str, -) -> str: - """Render nms configuration file based on Jinja template.""" - jinja2_environment = Environment(loader=FileSystemLoader("src/templates/")) - template = jinja2_environment.get_template("nmscfg.conf.j2") - return template.render( - common_database_name=common_database_name, - common_database_url=common_database_url, - auth_database_name=auth_database_name, - auth_database_url=auth_database_url, - ) - - class SDCoreNMSOperatorCharm(CharmBase): """Main class to describe juju event handling for the Aether SD-Core NMS operator for K8s.""" @@ -93,6 +85,13 @@ def __init__(self, *args): return self._container_name = self._service_name = "nms" self._container = self.unit.get_container(self._container_name) + self._tls = Tls( + charm=self, + relation_name=TLS_RELATION_NAME, + container=self._container, + domain_name=socket.getfqdn(), + workload_storage_path= CERTS_MOUNT_PATH + ) self._common_database = DatabaseRequires( self, relation_name=COMMON_DATABASE_RELATION_NAME, @@ -111,7 +110,9 @@ def __init__(self, *args): port=NMS_URL_PORT, relation_name="ingress", strip_prefix=True, + scheme=lambda: "https", ) + self.fiveg_n4 = N4Requires(charm=self, relation_name=FIVEG_N4_RELATION_NAME) self._gnb_identity = GnbIdentityRequires(self, GNB_IDENTITY_RELATION_NAME) self._logging = LogForwarder(charm=self, relation_name=LOGGING_RELATION_NAME) @@ -144,9 +145,20 @@ def __init__(self, *args): self.on[FIVEG_N4_RELATION_NAME].relation_broken, self._configure_sdcore_nms, ) + self.framework.observe(self.on.certificates_relation_joined, self._configure_sdcore_nms) + self.framework.observe( + self.on.certificates_relation_broken, self._on_certificates_relation_broken + ) + self.framework.observe( + self._tls._certificates.on.certificate_available, self._configure_sdcore_nms + ) # Handling config changed event to publish the new url if the unit reboots and gets new IP self.framework.observe(self.on.config_changed, self._configure_sdcore_nms) - self._nms = NMS(url=f"http://{self._nms_endpoint}") + self._nms = NMS( + url=f"https://{socket.getfqdn()}:{NMS_URL_PORT}", + ca_certificate_path=CA_CERTIFICATE_CHARM_PATH + ) + def _configure_sdcore_nms(self, event: EventBase) -> None: """Handle Juju events. @@ -159,13 +171,18 @@ def _configure_sdcore_nms(self, event: EventBase) -> None: return if not self._container.exists(path=BASE_CONFIG_PATH): return - for relation in [COMMON_DATABASE_RELATION_NAME, AUTH_DATABASE_RELATION_NAME]: + if not self._container.exists(path=CERTS_MOUNT_PATH): + return + for relation in MANDATORY_RELATIONS: if not self._relation_created(relation): return if not self._common_database_resource_is_available(): return if not self._auth_database_resource_is_available(): return + if not self._tls.certificate_is_available(): + logger.info("The TLS certificate is not available yet.") + return self._configure_workload() self._publish_sdcore_config_url() self._sync_gnbs() @@ -185,7 +202,7 @@ def _on_collect_unit_status(self, event: CollectStatusEvent): # noqa: C901 event.add_status(BlockedStatus("Scaling is not implemented for this charm")) logger.info("Scaling is not implemented for this charm") return - for relation in [COMMON_DATABASE_RELATION_NAME, AUTH_DATABASE_RELATION_NAME]: + for relation in MANDATORY_RELATIONS: if not self._relation_created(relation): event.add_status(BlockedStatus(f"Waiting for {relation} relation to be created")) logger.info("Waiting for %s relation to be created", relation) @@ -204,14 +221,18 @@ def _on_collect_unit_status(self, event: CollectStatusEvent): # noqa: C901 return self.unit.set_workload_version(self._get_workload_version()) - if not self._container.exists(path=BASE_CONFIG_PATH): + if (not self._container.exists(path=BASE_CONFIG_PATH) or + not self._container.exists(path=CERTS_MOUNT_PATH)): event.add_status(WaitingStatus("Waiting for storage to be attached")) logger.info("Waiting for storage to be attached") return if not self._nms_config_file_exists(): - event.add_status(WaitingStatus("Waiting for nms config file to be stored")) - logger.info("Waiting for nms config file to be stored") + event.add_status(WaitingStatus("Waiting for NMS config file to be stored")) + logger.info("Waiting for NMS config file to be stored") return + if not self._tls.certificate_is_available(): + event.add_status(WaitingStatus("Waiting for certificates to be available")) + logger.info("Waiting for certificates to be available") if not self._is_nms_service_running(): event.add_status(WaitingStatus("Waiting for NMS service to start")) logger.info("Waiting for NMS service to start") @@ -219,6 +240,13 @@ def _on_collect_unit_status(self, event: CollectStatusEvent): # noqa: C901 event.add_status(ActiveStatus()) + def _on_certificates_relation_broken(self, event: EventBase) -> None: + """Delete TLS related artifacts.""" + if not self._container.can_connect(): + event.defer() + return + self._tls.clean_up_certificates() + def _publish_sdcore_config_url(self) -> None: if not self._relation_created(SDCORE_CONFIG_RELATION_NAME): return @@ -227,8 +255,10 @@ def _publish_sdcore_config_url(self) -> None: self._sdcore_config.set_webui_url_in_all_relations(webui_url=self._nms_config_url) def _configure_workload(self): + certificate_update_required = self._tls.check_and_update_certificate() desired_config_file = self._generate_nms_config_file() - if not self._is_config_file_update_required(desired_config_file): + if (not self._is_config_file_update_required(desired_config_file) + and not certificate_update_required): self._configure_pebble() return self._write_file_in_workload(NMS_CONFIG_PATH, desired_config_file) @@ -254,11 +284,16 @@ def _nms_config_file_exists(self) -> bool: return bool(self._container.exists(NMS_CONFIG_PATH)) def _generate_nms_config_file(self) -> str: - return render_config_file( + """Render nms configuration file based on Jinja template.""" + jinja2_environment = Environment(loader=FileSystemLoader("src/templates/")) + template = jinja2_environment.get_template("nmscfg.conf.j2") + return template.render( common_database_name=COMMON_DATABASE_NAME, common_database_url=self._get_common_database_url(), auth_database_name=AUTH_DATABASE_NAME, auth_database_url=self._get_auth_database_url(), + tls_key_path=self._tls.private_key_workload_path, + tls_certificate_path=self._tls.certificate_workload_path, ) def _is_nms_service_running(self) -> bool: @@ -389,7 +424,7 @@ def _pebble_layer(self) -> Layer: "nms": { "override": "replace", "startup": "enabled", - "command": f"/bin/webconsole --webuicfg {NMS_CONFIG_PATH}", # noqa: E501 + "command": f"/bin/webconsole --cfg {NMS_CONFIG_PATH}", # noqa: E501 "environment": self._environment_variables, }, }, diff --git a/src/nms.py b/src/nms.py index 05d59f89..518cd67e 100644 --- a/src/nms.py +++ b/src/nms.py @@ -18,7 +18,6 @@ JSON_HEADER = {"Content-Type": "application/json"} - @dataclass class GnodeB: """Class to represent a gNB.""" @@ -52,10 +51,11 @@ class CreateUPFParams: class NMS: """Handle NMS API calls.""" - def __init__(self, url: str): + def __init__(self, url: str, ca_certificate_path: str = ""): if url.endswith("/"): url = url[:-1] self.url = url + self._ca_certificate_path = ca_certificate_path def _make_request( self, @@ -72,7 +72,11 @@ def _make_request( url=url, headers=headers, json=data, + verify=self._ca_certificate_path or False ) + except requests.exceptions.SSLError as e: + logger.error("SSL error: %s", e) + return None except requests.RequestException as e: logger.error("HTTP request failed: %s", e) return None @@ -115,7 +119,7 @@ def create_gnb(self, name: str, tac: int) -> None: def delete_gnb(self, name: str) -> None: """Delete a gNB list from the NMS inventory.""" self._make_request("DELETE", f"/{GNB_CONFIG_URL}/{name}") - logger.info("UPF %s deleted from NMS", name) + logger.info("gNB %s deleted from NMS", name) def list_upfs(self) -> List[Upf]: """List UPFs from the NMS inventory.""" diff --git a/src/templates/nmscfg.conf.j2 b/src/templates/nmscfg.conf.j2 index 74c76f4d..36c253a9 100644 --- a/src/templates/nmscfg.conf.j2 +++ b/src/templates/nmscfg.conf.j2 @@ -8,9 +8,13 @@ configuration: authKeysDbName: {{ auth_database_name }} authUrl: {{ auth_database_url }} spec-compliant-sdf: false + tls: + key: {{ tls_key_path }} + pem: {{ tls_certificate_path }} info: description: WebUI initial local configuration version: 1.0.0 + http-version: 2 logger: AMF: ReportCaller: false diff --git a/src/tls.py b/src/tls.py new file mode 100644 index 00000000..7dd453be --- /dev/null +++ b/src/tls.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Module use to handle TLS certificates for the NMS.""" + +import logging +from typing import Optional + +from charms.tls_certificates_interface.v4.tls_certificates import ( + Certificate, + CertificateRequestAttributes, + PrivateKey, + TLSCertificatesRequiresV4, +) +from ops import CharmBase, Container + +logger = logging.getLogger(__name__) + + +CERTIFICATE_COMMON_NAME = "nms.sdcore" +PRIVATE_KEY_NAME = "nms.key" +CERTIFICATE_NAME = "nms.pem" +CA_CERTIFICATE_NAME = "ca.pem" + + +class Tls: + """Handle TLS certificates.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + container: Container, + domain_name: str, + workload_storage_path: str + ): + self._storage_path = workload_storage_path + self._domain_name = domain_name + self._container = container + self._certificates = TLSCertificatesRequiresV4( + charm=charm, + relationship_name=relation_name, + certificate_requests=[self._get_certificate_request()], + ) + + def certificate_is_available(self) -> bool: + """Check if a valid certificate and private key are available. + + Returns: + bool: True if both the certificate and private key are available, + False otherwise. + """ + cert, key = self._certificates.get_assigned_certificate( + certificate_request=self._get_certificate_request() + ) + return bool(cert and key) + + def check_and_update_certificate(self) -> bool: + """Check if the certificate, CA certificate or private key needs an update and update it. + + This method retrieves the currently assigned certificate, CA certificate and private key + associated with the charm's TLS relation. It checks whether the certificate, + CA certificate, or private key has changed or needs updating. + If an update is necessary, the new certificates or private key is stored. + + Returns: + bool: True if either the certificate, CA certificate or the private key was updated. + False otherwise. + """ + provider_certificate, private_key = self._certificates.get_assigned_certificate( + certificate_request=self._get_certificate_request() + ) + if not provider_certificate or not private_key: + logger.debug("Certificate, CA certificate or private key is not available") + return False + if certificate_was_updated := self._is_certificate_update_required( + provider_certificate.certificate + ): + self._store_certificate(certificate=provider_certificate.certificate) + if ca_certificate_was_updated := self._is_ca_certificate_update_required( + provider_certificate.ca + ): + self._store_ca_certificate(ca=provider_certificate.ca) + if private_key_was_updated := self._is_private_key_update_required(private_key): + self._store_private_key(private_key=private_key) + return certificate_was_updated or ca_certificate_was_updated or private_key_was_updated + + def clean_up_certificates(self) -> None: + """Remove all certificate-related files from storage.""" + self._delete_private_key() + self._delete_certificate() + self._delete_ca_certificate() + + def _is_certificate_update_required(self, certificate: Certificate) -> bool: + return self._get_existing_certificate() != certificate + + def _is_private_key_update_required(self, private_key: PrivateKey) -> bool: + return self._get_existing_private_key() != private_key + + def _is_ca_certificate_update_required(self, certificate: Certificate) -> bool: + return self._get_existing_ca_certificate() != certificate + + def _get_existing_certificate(self) -> Optional[Certificate]: + return self._get_stored_certificate() if self._certificate_is_stored() else None + + def _get_existing_private_key(self) -> Optional[PrivateKey]: + return self._get_stored_private_key() if self._private_key_is_stored() else None + + def _get_existing_ca_certificate(self) -> Optional[Certificate]: + return self._get_stored_ca_certificate() if self._ca_certificate_is_stored() else None + + def _delete_certificate(self) -> None: + if not self._certificate_is_stored(): + return + self._container.remove_path(path=self.certificate_workload_path) + logger.info("Removed certificate from workload") + + def _delete_private_key(self) -> None: + if not self._private_key_is_stored(): + return + self._container.remove_path(path=self.private_key_workload_path) + logger.info("Removed private key from workload") + + def _delete_ca_certificate(self) -> None: + if not self._ca_certificate_is_stored(): + return + self._container.remove_path(path=self.ca_certificate_workload_path) + logger.info("Removed CA certificate from workload") + + def _get_stored_certificate(self) -> Certificate: + cert_string = str(self._container.pull(path=self.certificate_workload_path).read()) + return Certificate.from_string(cert_string) + + def _get_stored_private_key(self) -> PrivateKey: + key_string = str(self._container.pull(path=self.private_key_workload_path).read()) + return PrivateKey.from_string(key_string) + + def _get_stored_ca_certificate(self) -> Certificate: + cert_string = str(self._container.pull(path=self.ca_certificate_workload_path).read()) + return Certificate.from_string(cert_string) + + def _certificate_is_stored(self) -> bool: + return self._container.exists(path=self.certificate_workload_path) + + def _private_key_is_stored(self) -> bool: + return self._container.exists(path=self.private_key_workload_path) + + def _ca_certificate_is_stored(self) -> bool: + return self._container.exists(path=self.ca_certificate_workload_path) + + def _store_certificate(self, certificate: Certificate) -> None: + self._container.push(path=self.certificate_workload_path, source=str(certificate)) + logger.info("Pushed certificate to workload") + + def _store_private_key(self, private_key: PrivateKey) -> None: + self._container.push(path=self.private_key_workload_path, source=str(private_key)) + logger.info("Pushed private key to workload") + + def _store_ca_certificate(self, ca: Certificate) -> None: + self._container.push(path=self.ca_certificate_workload_path, source=str(ca)) + logger.info("Pushed CA certificate to workload") + + def _get_certificate_request(self) -> CertificateRequestAttributes: + return CertificateRequestAttributes( + common_name=CERTIFICATE_COMMON_NAME, + sans_dns=frozenset([self._domain_name]), + ) + + @property + def certificate_workload_path(self) -> str: + """Path to the certificate file in the workload storage.""" + return f"{self._storage_path}/{CERTIFICATE_NAME}" + + @property + def private_key_workload_path(self) -> str: + """Path to the private key file in the workload storage.""" + return f"{self._storage_path}/{PRIVATE_KEY_NAME}" + + @property + def ca_certificate_workload_path(self) -> str: + """Path to the CA certificate file in the workload storage.""" + return f"{self._storage_path}/{CA_CERTIFICATE_NAME}" diff --git a/terraform/outputs.tf b/terraform/outputs.tf index f3b6de7f..babcf58a 100644 --- a/terraform/outputs.tf +++ b/terraform/outputs.tf @@ -10,6 +10,7 @@ output "requires" { value = { auth_database = "auth_database" common_database = "common_database" + certificates = "certificates" fiveg_gnb_identity = "fiveg_gnb_identity" fiveg_n4 = "fiveg_n4" ingress = "ingress" diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 9eb93dda..45f65f4e 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -7,7 +7,6 @@ import time from collections import Counter from pathlib import Path -from typing import List import pytest import requests @@ -72,6 +71,14 @@ async def _deploy_traefik(ops_test: OpsTest): channel=TRAEFIK_CHARM_CHANNEL, trust=True, ) + # TODO: This is a workaround so Traefik has the same CA as NMS. + # This should be removed and V1 of the certificate transfer interface should be used instead + # The following PR is needed to get Traefik to implement V1 of certificate transfer interface: + # https://github.com/canonical/traefik-k8s-operator/issues/407 + await ops_test.model.integrate( + relation1=f"{TRAEFIK_CHARM_NAME}:certificates", + relation2=TLS_PROVIDER_CHARM_NAME + ) async def configure_traefik(ops_test: OpsTest, traefik_ip: str) -> None: @@ -174,9 +181,9 @@ async def get_traefik_ip_address(ops_test: OpsTest) -> str: return _get_host_from_url(endpoints[TRAEFIK_CHARM_NAME]["url"]) -async def get_sdcore_nms_endpoint(ops_test: OpsTest) -> str: +async def get_sdcore_nms_external_endpoint(ops_test: OpsTest) -> str: endpoints = await get_traefik_proxied_endpoints(ops_test) - return endpoints[APP_NAME]["url"] + return endpoints[APP_NAME]["url"].rstrip("/") def _get_host_from_url(url: str) -> str: @@ -185,12 +192,13 @@ def _get_host_from_url(url: str) -> str: def ui_is_running(nms_endpoint: str) -> bool: - url = f"{nms_endpoint}network-configuration" + url = f"{nms_endpoint}/network-configuration" + logger.info(f"Reaching NMS UI at {url}") t0 = time.time() timeout = 300 # seconds while time.time() - t0 < timeout: try: - response = requests.get(url=url, timeout=5) + response = requests.get(url=url, timeout=5, verify=False) response.raise_for_status() logger.info(response.content.decode("utf-8")) if "5G NMS" in response.content.decode("utf-8"): @@ -201,20 +209,6 @@ def ui_is_running(nms_endpoint: str) -> bool: return False -def get_nms_inventory_resource(url: str) -> List: - t0 = time.time() - timeout = 100 # seconds - while time.time() - t0 < timeout: - try: - response = requests.get(url=url, timeout=5) - response.raise_for_status() - return response.json() - except Exception as e: - logger.error("Cannot connect to the nms inventory: %s", e) - time.sleep(2) - return [] - - @pytest.fixture(scope="module") @pytest.mark.abort_on_fail async def deploy(ops_test: OpsTest, request): @@ -232,8 +226,8 @@ async def deploy(ops_test: OpsTest, request): ) await _deploy_database(ops_test) await _deploy_grafana_agent(ops_test) - await _deploy_traefik(ops_test) await _deploy_self_signed_certificates(ops_test) + await _deploy_traefik(ops_test) await _deploy_nrf(ops_test) await _deploy_sdcore_gnbsim(ops_test) await _deploy_amf(ops_test) @@ -261,6 +255,7 @@ async def test_relate_and_wait_for_active_status(ops_test: OpsTest, deploy): await ops_test.model.integrate( relation1=f"{APP_NAME}:{AUTH_DATABASE_RELATION_NAME}", relation2=DATABASE_APP_NAME ) + await ops_test.model.integrate(relation1=APP_NAME, relation2=TLS_PROVIDER_CHARM_NAME) await ops_test.model.integrate( relation1=f"{APP_NAME}:{LOGGING_RELATION_NAME}", relation2=GRAFANA_AGENT_APP_NAME ) @@ -314,9 +309,9 @@ async def test_given_related_to_traefik_when_fetch_ui_then_returns_html_content( ): # Workaround for Traefik issue: https://github.com/canonical/traefik-k8s-operator/issues/361 traefik_ip = await get_traefik_ip_address(ops_test) - logger.info(traefik_ip) await configure_traefik(ops_test, traefik_ip) - nms_url = await get_sdcore_nms_endpoint(ops_test) + nms_url = await get_sdcore_nms_external_endpoint(ops_test) + assert nms_url.startswith("https") assert ui_is_running(nms_endpoint=nms_url) @@ -326,13 +321,15 @@ async def test_given_nms_related_to_gnbsim_and_gnbsim_status_is_active_then_nms_ ): assert ops_test.model await ops_test.model.wait_for_idle(apps=[GNBSIM_CHARM_NAME], status="active", timeout=TIMEOUT) - nms_url = await get_sdcore_nms_endpoint(ops_test) + nms_url = await get_sdcore_nms_external_endpoint(ops_test) nms_client = NMS(url=nms_url) gnbs = nms_client.list_gnbs() expected_gnb_name = f"{ops_test.model.name}-gnbsim-{GNBSIM_CHARM_NAME}" expected_gnb = GnodeB(name=expected_gnb_name, tac=1) + logger.info(expected_gnb_name) + logger.info(expected_gnb) assert gnbs == [expected_gnb] @@ -342,7 +339,7 @@ async def test_given_nms_related_to_upf_and_upf_status_is_active_then_nms_invent ): assert ops_test.model await ops_test.model.wait_for_idle(apps=[UPF_CHARM_NAME], status="active", timeout=TIMEOUT) - nms_url = await get_sdcore_nms_endpoint(ops_test) + nms_url = await get_sdcore_nms_external_endpoint(ops_test) nms_client = NMS(url=nms_url) upfs = nms_client.list_upfs() @@ -360,7 +357,8 @@ async def test_given_gnb_and_upf_are_remove_then_nms_inventory_does_not_contain_ await ops_test.model.remove_application(UPF_CHARM_NAME, block_until_done=False) await ops_test.model.remove_application(GNBSIM_CHARM_NAME, block_until_done=True) await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=TIMEOUT) - nms_url = await get_sdcore_nms_endpoint(ops_test) + nms_url = await get_sdcore_nms_external_endpoint(ops_test) + nms_client = NMS(url=nms_url) gnbs = nms_client.list_gnbs() @@ -370,6 +368,21 @@ async def test_given_gnb_and_upf_are_remove_then_nms_inventory_does_not_contain_ assert upfs == [] +@pytest.mark.abort_on_fail +async def test_remove_tls_and_wait_for_blocked_status(ops_test: OpsTest, deploy): + assert ops_test.model + await ops_test.model.remove_application(TLS_PROVIDER_CHARM_NAME, block_until_done=True) + await ops_test.model.wait_for_idle(apps=[APP_NAME], status="blocked", timeout=60) + + +@pytest.mark.abort_on_fail +async def test_restore_tls_and_wait_for_active_status(ops_test: OpsTest, deploy): + assert ops_test.model + await _deploy_self_signed_certificates(ops_test) + await ops_test.model.integrate(relation1=APP_NAME, relation2=TLS_PROVIDER_CHARM_NAME) + await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=TIMEOUT) + + @pytest.mark.abort_on_fail async def test_when_scale_app_beyond_1_then_only_one_unit_is_active(ops_test: OpsTest, deploy): assert ops_test.model diff --git a/tests/unit/certificates_helpers.py b/tests/unit/certificates_helpers.py new file mode 100644 index 00000000..cf77cc99 --- /dev/null +++ b/tests/unit/certificates_helpers.py @@ -0,0 +1,41 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +from datetime import timedelta + +from charms.tls_certificates_interface.v4.tls_certificates import ( + PrivateKey, + ProviderCertificate, + generate_ca, + generate_certificate, + generate_csr, + generate_private_key, +) + + +def example_cert_and_key(relation_id: int = 1) -> tuple[ProviderCertificate, PrivateKey]: + private_key = generate_private_key() + csr = generate_csr( + private_key=private_key, + common_name="nms", + ) + ca_private_key = generate_private_key() + ca_certificate = generate_ca( + private_key=ca_private_key, + common_name="ca.com", + validity=timedelta(days=365), + ) + certificate = generate_certificate( + csr=csr, + ca=ca_certificate, + ca_private_key=ca_private_key, + validity=timedelta(days=365), + ) + provider_certificate = ProviderCertificate( + relation_id=relation_id, + certificate=certificate, + certificate_signing_request=csr, + ca=ca_certificate, + chain=[ca_certificate], + ) + return provider_certificate, private_key diff --git a/tests/unit/expected_nms_cfg.yaml b/tests/unit/expected_nms_cfg.yaml index de007cc5..a36c6240 100644 --- a/tests/unit/expected_nms_cfg.yaml +++ b/tests/unit/expected_nms_cfg.yaml @@ -8,9 +8,13 @@ configuration: authKeysDbName: authentication authUrl: 1.8.11.4:1234 spec-compliant-sdf: false + tls: + key: /support/TLS/nms.key + pem: /support/TLS/nms.pem info: description: WebUI initial local configuration version: 1.0.0 + http-version: 2 logger: AMF: ReportCaller: false diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index 979e998b..e44e3c3b 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -9,7 +9,7 @@ from charm import SDCoreNMSOperatorCharm -class NMSUnitTestFixtures: +class BaseNMSUnitTestFixtures: patcher_check_output = patch("charm.check_output") patcher_set_webui_url_in_all_relations = patch( "charms.sdcore_nms_k8s.v0.sdcore_config.SdcoreConfigProvides.set_webui_url_in_all_relations" @@ -21,20 +21,17 @@ class NMSUnitTestFixtures: patcher_nms_create_upf = patch("nms.NMS.create_upf") patcher_nms_delete_upf = patch("nms.NMS.delete_upf") - @pytest.fixture(autouse=True) - def setUp(self, request): - self.mock_check_output = NMSUnitTestFixtures.patcher_check_output.start() + def common_setup(self): + self.mock_check_output = BaseNMSUnitTestFixtures.patcher_check_output.start() self.mock_set_webui_url_in_all_relations = ( - NMSUnitTestFixtures.patcher_set_webui_url_in_all_relations.start() + BaseNMSUnitTestFixtures.patcher_set_webui_url_in_all_relations.start() ) - self.mock_list_gnbs = NMSUnitTestFixtures.patcher_nms_list_gnbs.start() - self.mock_create_gnb = NMSUnitTestFixtures.patcher_nms_create_gnb.start() - self.mock_delete_gnb = NMSUnitTestFixtures.patcher_nms_delete_gnb.start() - self.mock_list_upfs = NMSUnitTestFixtures.patcher_nms_list_upfs.start() - self.mock_create_upf = NMSUnitTestFixtures.patcher_nms_create_upf.start() - self.mock_delete_upf = NMSUnitTestFixtures.patcher_nms_delete_upf.start() - yield - request.addfinalizer(self.tearDown) + self.mock_list_gnbs = BaseNMSUnitTestFixtures.patcher_nms_list_gnbs.start() + self.mock_create_gnb = BaseNMSUnitTestFixtures.patcher_nms_create_gnb.start() + self.mock_delete_gnb = BaseNMSUnitTestFixtures.patcher_nms_delete_gnb.start() + self.mock_list_upfs = BaseNMSUnitTestFixtures.patcher_nms_list_upfs.start() + self.mock_create_upf = BaseNMSUnitTestFixtures.patcher_nms_create_upf.start() + self.mock_delete_upf = BaseNMSUnitTestFixtures.patcher_nms_delete_upf.start() @staticmethod def tearDown() -> None: @@ -45,3 +42,35 @@ def context(self): self.ctx = scenario.Context( charm_type=SDCoreNMSOperatorCharm, ) + + +class NMSUnitTestFixtures(BaseNMSUnitTestFixtures): + patcher_certificate_is_available = patch("tls.Tls.certificate_is_available") + patcher_check_and_update_certificate = patch("tls.Tls.check_and_update_certificate") + + @pytest.fixture(autouse=True) + def setUp(self, request): + self.common_setup() + self.mock_certificate_is_available = ( + NMSUnitTestFixtures.patcher_certificate_is_available.start() + ) + self.mock_check_and_update_certificate = ( + NMSUnitTestFixtures.patcher_check_and_update_certificate.start() + ) + yield + request.addfinalizer(self.tearDown) + +class NMSTlsCertificatesFixtures(BaseNMSUnitTestFixtures): + + patcher_get_assigned_certificate = patch( + "charms.tls_certificates_interface.v4.tls_certificates.TLSCertificatesRequiresV4.get_assigned_certificate" + ) + + @pytest.fixture(autouse=True) + def setUp(self, request): + self.common_setup() + self.mock_get_assigned_certificate = ( + NMSTlsCertificatesFixtures.patcher_get_assigned_certificate.start() + ) + yield + request.addfinalizer(self.tearDown) diff --git a/tests/unit/test_charm_collect_status.py b/tests/unit/test_charm_collect_status.py index e382320e..48bb00bf 100644 --- a/tests/unit/test_charm_collect_status.py +++ b/tests/unit/test_charm_collect_status.py @@ -1,7 +1,6 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. - import tempfile import scenario @@ -26,7 +25,13 @@ def test_given_common_database_relation_not_created_when_collect_unit_status_the endpoint="auth_database", interface="mongodb_client", ) - state_in = scenario.State(leader=True, relations={auth_database_relation}) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + state_in = scenario.State( + leader=True, + relations={auth_database_relation, certificates_relation} + ) state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) @@ -41,7 +46,13 @@ def test_given_auth_database_relation_not_created_when_collect_unit_status_then_ endpoint="common_database", interface="mongodb_client", ) - state_in = scenario.State(leader=True, relations={common_database_relation}) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + state_in = scenario.State( + leader=True, + relations={common_database_relation, certificates_relation} + ) state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) @@ -49,6 +60,28 @@ def test_given_auth_database_relation_not_created_when_collect_unit_status_then_ "Waiting for auth_database relation to be created" ) + def test_given_certificates_relation_not_created_when_collect_unit_status_then_status_is_blocked( # noqa: E501 + self, + ): + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + ) + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + ) + state_in = scenario.State( + leader=True, + relations={common_database_relation, auth_database_relation} + ) + + state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) + + assert state_out.unit_status == BlockedStatus( + "Waiting for certificates relation to be created" + ) + def test_given_common_db_relation_is_created_but_not_available_when_collect_unit_status_then_status_is_waiting( # noqa: E501 self, ): @@ -64,9 +97,12 @@ def test_given_common_db_relation_is_created_but_not_available_when_collect_unit common_database_relation = scenario.Relation( endpoint="common_database", interface="mongodb_client", remote_app_data={} ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={auth_database_relation, common_database_relation, certificates_relation}, ) state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) @@ -90,9 +126,12 @@ def test_given_auth_db_relation_is_created_but_not_available_when_collect_unit_s "uris": "1.2.3.4:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={auth_database_relation, common_database_relation, certificates_relation}, ) state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) @@ -122,13 +161,16 @@ def test_given_storage_attached_but_cannot_connect_to_container_when_collect_uni "uris": "2.3.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) container = scenario.Container( name="nms", can_connect=False, ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={auth_database_relation, common_database_relation, certificates_relation}, containers={container}, ) @@ -157,6 +199,9 @@ def test_given_storage_not_attached_when_collect_unit_status_then_status_is_wait "uris": "11.11.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) container = scenario.Container( name="nms", @@ -165,7 +210,7 @@ def test_given_storage_not_attached_when_collect_unit_status_then_status_is_wait ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={auth_database_relation, common_database_relation, certificates_relation}, containers={container}, ) @@ -173,7 +218,55 @@ def test_given_storage_not_attached_when_collect_unit_status_then_status_is_wait assert state_out.unit_status == WaitingStatus("Waiting for storage to be attached") - def test_given_nms_config_file_does_not_exist_when_collect_unit_status_then_status_is_waiting( # noqa: E501 + def test_given_config_storage_not_attached_when_collect_unit_status_then_status_is_waiting( + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "apple", + "password": "hamburger", + "uris": "1.8.11.4:1234", + }, + ) + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "11.11.1.1:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + + container = scenario.Container( + name="nms", + can_connect=True, + mounts={"certs": certs_mount}, + ) + state_in = scenario.State( + leader=True, + relations={auth_database_relation, + common_database_relation, + certificates_relation, + }, + containers={container}, + ) + + state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) + + assert state_out.unit_status == WaitingStatus("Waiting for storage to be attached") + + def test_given_certs_storage_not_attached_when_collect_unit_status_then_status_is_waiting( self, ): with tempfile.TemporaryDirectory() as tempdir: @@ -195,10 +288,14 @@ def test_given_nms_config_file_does_not_exist_when_collect_unit_status_then_stat "uris": "11.11.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + container = scenario.Container( name="nms", can_connect=True, @@ -206,16 +303,128 @@ def test_given_nms_config_file_does_not_exist_when_collect_unit_status_then_stat ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={ + auth_database_relation, + common_database_relation, + certificates_relation, + }, + containers={container}, + ) + + state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) + + assert state_out.unit_status == WaitingStatus("Waiting for storage to be attached") + + def test_given_nms_config_file_does_not_exist_when_collect_unit_status_then_status_is_waiting( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "apple", + "password": "hamburger", + "uris": "1.8.11.4:1234", + }, + ) + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "11.11.1.1:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={"config": config_mount, "certs": certs_mount}, + ) + state_in = scenario.State( + leader=True, + relations={auth_database_relation, + common_database_relation, + certificates_relation, + }, containers={container}, ) state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) assert state_out.unit_status == WaitingStatus( - "Waiting for nms config file to be stored" + "Waiting for NMS config file to be stored" ) + def test_given_certificates_not_stored_when_collect_unit_status_then_status_is_waiting( + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "apple", + "password": "hamburger", + "uris": "1.2.3.4:1234", + }, + ) + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "2.2.2.2:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={"config": config_mount, "certs": certs_mount,}, + ) + state_in = scenario.State( + leader=True, + relations={ + auth_database_relation, + common_database_relation, + certificates_relation + }, + containers={container}, + ) + self.mock_certificate_is_available.return_value = False + with open(f"{tempdir}/nmscfg.conf", "w") as f: + f.write("whatever config file content") + + state_out = self.ctx.run(self.ctx.on.collect_unit_status(), state_in) + + assert state_out.unit_status == WaitingStatus( + "Waiting for certificates to be available" + ) + def test_given_service_is_not_running_when_collect_unit_status_then_status_is_waiting( self, ): @@ -238,20 +447,32 @@ def test_given_service_is_not_running_when_collect_unit_status_then_status_is_wa "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, - mounts={"config": config_mount}, + mounts={"config": config_mount, "certs": certs_mount}, ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={ + auth_database_relation, + common_database_relation, + certificates_relation + }, containers={container}, ) + self.mock_certificate_is_available.return_value = True with open(f"{tempdir}/nmscfg.conf", "w") as f: f.write("whatever config file content") @@ -281,22 +502,35 @@ def test_given_container_ready_db_relations_exist_storage_attached_and_config_fi "uris": "1.1.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, - mounts={"config": config_mount}, + mounts={"config": config_mount, "certs": certs_mount}, layers={"nms": Layer({"services": {"nms": {}}})}, service_statuses={"nms": ServiceStatus.ACTIVE}, ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={ + auth_database_relation, + common_database_relation, + certificates_relation, + }, containers={container}, ) + self.mock_certificate_is_available.return_value = True + with open(f"{tempdir}/nmscfg.conf", "w") as f: f.write("whatever config file content") @@ -325,6 +559,9 @@ def test_given_no_workload_version_file_when_collect_unit_status_then_workload_v "uris": "1.1.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) container = scenario.Container( name="nms", @@ -334,7 +571,7 @@ def test_given_no_workload_version_file_when_collect_unit_status_then_workload_v ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={auth_database_relation, common_database_relation, certificates_relation}, containers={container}, ) @@ -369,6 +606,10 @@ def test_given_workload_version_file_when_collect_unit_status_then_workload_vers "uris": "1.1.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + container = scenario.Container( name="nms", can_connect=True, @@ -376,7 +617,11 @@ def test_given_workload_version_file_when_collect_unit_status_then_workload_vers ) state_in = scenario.State( leader=True, - relations={auth_database_relation, common_database_relation}, + relations={ + auth_database_relation, + common_database_relation, + certificates_relation, + }, containers={container}, ) diff --git a/tests/unit/test_charm_configure.py b/tests/unit/test_charm_configure.py index 6e930cc1..89ed9a74 100644 --- a/tests/unit/test_charm_configure.py +++ b/tests/unit/test_charm_configure.py @@ -16,8 +16,8 @@ EXPECTED_CONFIG_FILE_PATH = "tests/unit/expected_nms_cfg.yaml" - class TestCharmConfigure(NMSUnitTestFixtures): + def test_given_db_relations_do_not_exist_when_pebble_ready_then_nms_config_file_is_not_written( # noqa: E501 self, ): @@ -59,22 +59,35 @@ def test_given_common_db_resource_not_available_when_pebble_ready_then_nms_confi "uris": "11.11.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( leader=True, containers={container}, - relations={common_database_relation, auth_database_relation}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, ) + self.mock_check_and_update_certificate.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -97,28 +110,41 @@ def test_given_auth_db_resource_not_available_when_pebble_ready_then_nms_config_ endpoint="auth_database", interface="mongodb_client", ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( leader=True, containers={container}, - relations={common_database_relation, auth_database_relation}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, ) + self.mock_check_and_update_certificate.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) assert not os.path.exists(f"{tempdir}/nmscfg.conf") - def test_given_storage_attached_and_nms_config_file_does_not_exist_when_pebble_ready_then_config_file_is_written( # noqa: E501 + def test_given_certificates_relation_doesnt_exist_when_pebble_ready_then_nms_config_file_is_not_written( # noqa: E501 self, ): with tempfile.TemporaryDirectory() as tempdir: @@ -144,11 +170,16 @@ def test_given_storage_attached_and_nms_config_file_does_not_exist_when_pebble_r location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -159,6 +190,125 @@ def test_given_storage_attached_and_nms_config_file_does_not_exist_when_pebble_r self.ctx.run(self.ctx.on.pebble_ready(container), state_in) + assert not os.path.exists(f"{tempdir}/nmscfg.conf") + + def test_given_tls_certificate_not_available_when_pebble_ready_then_nms_config_file_is_not_written( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.2.3.4:5678", + }, + ) + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "11.11.1.1:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={ + "config": config_mount, + "certs": certs_mount, + }, + ) + state_in = scenario.State( + leader=True, + containers={container}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, + ) + self.mock_certificate_is_available.return_value = False + + self.ctx.run(self.ctx.on.pebble_ready(container), state_in) + + assert not os.path.exists(f"{tempdir}/nmscfg.conf") + + @pytest.mark.parametrize( + "certificate_was_updated", + [ + True, + False, + ] + ) + def test_given_storage_attached_and_nms_config_file_does_not_exist_when_pebble_ready_then_config_file_is_written( # noqa: E501 + self, certificate_was_updated + ): + with tempfile.TemporaryDirectory() as tempdir: + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.9.11.4:1234", + }, + ) + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.8.11.4:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={ + "config": config_mount, + "certs": certs_mount, + }, + ) + state_in = scenario.State( + leader=True, + containers={container}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, + ) + self.mock_check_and_update_certificate.return_value = certificate_was_updated + + self.ctx.run(self.ctx.on.pebble_ready(container), state_in) + assert os.path.exists(f"{tempdir}/nmscfg.conf") with open(f"{tempdir}/nmscfg.conf", "r") as f: assert f.read() == open(EXPECTED_CONFIG_FILE_PATH, "r").read() @@ -185,22 +335,35 @@ def test_given_container_is_ready_db_relations_exist_and_storage_attached_when_p "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( leader=True, containers={container}, - relations={common_database_relation, auth_database_relation}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, ) + self.mock_certificate_is_available.return_value = True state_out = self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -212,7 +375,7 @@ def test_given_container_is_ready_db_relations_exist_and_storage_attached_when_p "nms": { "startup": "enabled", "override": "replace", - "command": "/bin/webconsole --webuicfg /nms/config/nmscfg.conf", + "command": "/bin/webconsole --cfg /nms/config/nmscfg.conf", "environment": { "GRPC_GO_LOG_VERBOSITY_LEVEL": "99", "GRPC_GO_LOG_SEVERITY_LEVEL": "info", @@ -226,17 +389,24 @@ def test_given_container_is_ready_db_relations_exist_and_storage_attached_when_p } ) - def test_given_db_relations_do_not_exist_when_pebble_ready_then_pebble_plan_is_empty(self): + def test_given_mandatory_relations_do_not_exist_when_pebble_ready_then_pebble_plan_is_empty( + self + ): with tempfile.TemporaryDirectory() as tempdir: config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -273,6 +443,9 @@ def test_given_storage_not_attached_when_pebble_ready_then_config_url_is_not_pub "uris": "2.1.1.1:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) container = scenario.Container( name="nms", can_connect=True, @@ -280,14 +453,20 @@ def test_given_storage_not_attached_when_pebble_ready_then_config_url_is_not_pub state_in = scenario.State( leader=True, containers={container}, - relations={sdcore_config_relation, common_database_relation, auth_database_relation}, + relations={ + sdcore_config_relation, + common_database_relation, + auth_database_relation, + certificates_relation, + }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) self.mock_set_webui_url_in_all_relations.assert_not_called() - def test_given_nms_service_is_running_db_relations_are_not_joined_when_pebble_ready_then_config_url_is_not_published_for_relations( # noqa: E501 + def test_given_nms_service_is_running_mandatory_relations_are_not_joined_when_pebble_ready_then_config_url_is_not_published_for_relations( # noqa: E501 self, ): with tempfile.TemporaryDirectory() as tempdir: @@ -338,7 +517,9 @@ def test_given_nms_service_is_running_db_relations_are_joined_when_several_sdcor "uris": "2.2.2.2:1234", }, ) - + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) sdcore_config_relation_1 = scenario.Relation( endpoint="sdcore_config", interface="sdcore_config", @@ -351,11 +532,16 @@ def test_given_nms_service_is_running_db_relations_are_joined_when_several_sdcor location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -364,11 +550,12 @@ def test_given_nms_service_is_running_db_relations_are_joined_when_several_sdcor relations={ auth_database_relation, common_database_relation, + certificates_relation, sdcore_config_relation_1, sdcore_config_relation_2, }, ) - + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) self.mock_set_webui_url_in_all_relations.assert_called_with( @@ -397,6 +584,9 @@ def test_given_nms_service_is_not_running_when_pebble_ready_then_config_url_is_n "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) sdcore_config_relation = scenario.Relation( endpoint="sdcore_config", interface="sdcore_config", @@ -405,11 +595,16 @@ def test_given_nms_service_is_not_running_when_pebble_ready_then_config_url_is_n location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=False, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -418,9 +613,11 @@ def test_given_nms_service_is_not_running_when_pebble_ready_then_config_url_is_n relations={ auth_database_relation, common_database_relation, + certificates_relation, sdcore_config_relation, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -545,6 +742,9 @@ def test_given_incomplete_data_in_relation_when_pebble_ready_then_is_not_updated "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) relation = scenario.Relation( endpoint=relation_name, interface=relation_name, @@ -563,9 +763,15 @@ def test_given_incomplete_data_in_relation_when_pebble_ready_then_is_not_updated ) state_in = scenario.State( leader=True, - relations={relation, auth_database_relation, common_database_relation}, + relations={ + relation, + auth_database_relation, + common_database_relation, + certificates_relation, + }, containers={container}, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -574,7 +780,9 @@ def test_given_incomplete_data_in_relation_when_pebble_ready_then_is_not_updated self.mock_delete_gnb.assert_not_called() self.mock_delete_upf.assert_not_called() - def test_given_no_db_relations_when_pebble_ready_then_nms_resources_are_not_updated(self): + def test_given_no_mandatory_relations_when_pebble_ready_then_nms_inventory_is_not_updated( + self + ): with tempfile.TemporaryDirectory() as tempdir: fiveg_gnb_identity_relation = scenario.Relation( endpoint="fiveg_gnb_identity", @@ -616,7 +824,7 @@ def test_given_no_db_relations_when_pebble_ready_then_nms_resources_are_not_upda self.mock_create_upf.assert_not_called() self.mock_delete_upf.assert_not_called() - def test_given_db_relations_when_pebble_ready_then_nms_upf_is_updated( + def test_given_mandatory_relations_when_pebble_ready_then_nms_upf_is_updated( self, ): with tempfile.TemporaryDirectory() as tempdir: @@ -638,6 +846,9 @@ def test_given_db_relations_when_pebble_ready_then_nms_upf_is_updated( "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) fiveg_gnb_identity_relation = scenario.Relation( endpoint="fiveg_gnb_identity", interface="fiveg_gnb_identity", @@ -658,11 +869,16 @@ def test_given_db_relations_when_pebble_ready_then_nms_upf_is_updated( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -671,16 +887,18 @@ def test_given_db_relations_when_pebble_ready_then_nms_upf_is_updated( relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_gnb_identity_relation, fiveg_n4_relation, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) self.mock_create_upf.assert_called_once_with(hostname="some.host.name", port=1234) - def test_given_db_relations_when_pebble_ready_then_nms_gnb_is_updated( + def test_given_mandatory_relations_when_pebble_ready_then_nms_gnb_is_updated( self, ): with tempfile.TemporaryDirectory() as tempdir: @@ -702,6 +920,9 @@ def test_given_db_relations_when_pebble_ready_then_nms_gnb_is_updated( "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) fiveg_gnb_identity_relation = scenario.Relation( endpoint="fiveg_gnb_identity", interface="fiveg_gnb_identity", @@ -722,11 +943,16 @@ def test_given_db_relations_when_pebble_ready_then_nms_gnb_is_updated( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -735,10 +961,12 @@ def test_given_db_relations_when_pebble_ready_then_nms_gnb_is_updated( relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_gnb_identity_relation, fiveg_n4_relation, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -766,6 +994,9 @@ def test_given_multiple_n4_relations_when_pebble_ready_then_both_upfs_are_added_ "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) fiveg_n4_relation_1 = scenario.Relation( endpoint="fiveg_n4", interface="fiveg_n4", @@ -786,11 +1017,16 @@ def test_given_multiple_n4_relations_when_pebble_ready_then_both_upfs_are_added_ location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -799,10 +1035,12 @@ def test_given_multiple_n4_relations_when_pebble_ready_then_both_upfs_are_added_ relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_n4_relation_1, fiveg_n4_relation_2, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -834,6 +1072,9 @@ def test_given_multiple_gnb_relations_when_pebble_ready_then_both_gnbs_are_added "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) fiveg_gnb_identity_relation_1 = scenario.Relation( endpoint="fiveg_gnb_identity", interface="fiveg_gnb_identity", @@ -854,11 +1095,16 @@ def test_given_multiple_gnb_relations_when_pebble_ready_then_both_gnbs_are_added location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) state_in = scenario.State( @@ -867,10 +1113,12 @@ def test_given_multiple_gnb_relations_when_pebble_ready_then_both_gnbs_are_added relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_gnb_identity_relation_1, fiveg_gnb_identity_relation_2, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -902,16 +1150,24 @@ def test_given_upf_exist_in_nms_and_relation_matches_when_pebble_ready_then_nms_ "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) self.mock_list_upfs.return_value = [Upf(hostname="some.host.name", port=1234)] config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_n4_relation = scenario.Relation( @@ -925,8 +1181,14 @@ def test_given_upf_exist_in_nms_and_relation_matches_when_pebble_ready_then_nms_ state_in = scenario.State( leader=True, containers={container}, - relations={common_database_relation, auth_database_relation, fiveg_n4_relation}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + fiveg_n4_relation, + }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -955,17 +1217,25 @@ def test_given_gnb_exist_in_nms_and_relation_matches_when_pebble_ready_then_nms_ "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_gnbs = [GnodeB(name="some.gnb.name", tac=1234)] self.mock_list_gnbs.return_value = existing_gnbs config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_gnb_identity_relation = scenario.Relation( @@ -982,9 +1252,11 @@ def test_given_gnb_exist_in_nms_and_relation_matches_when_pebble_ready_then_nms_ relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_gnb_identity_relation, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.pebble_ready(container), state_in) @@ -1038,17 +1310,25 @@ def test_given_upf_exists_in_nms_and_new_upf_relation_is_added_when_pebble_ready "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_upf = Upf(hostname="some.host.name", port=1234) self.mock_list_upfs.return_value = [existing_upf] config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_n4_relation_1 = scenario.Relation( @@ -1073,10 +1353,12 @@ def test_given_upf_exists_in_nms_and_new_upf_relation_is_added_when_pebble_ready relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_n4_relation_1, fiveg_n4_relation_2, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_joined(fiveg_n4_relation_2), state_in) @@ -1105,17 +1387,25 @@ def test_given_gnb_exists_in_nms_and_new_gnb_relation_is_added_when_pebble_ready "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_gnbs = [GnodeB(name="some.gnb.name", tac=1234)] self.mock_list_gnbs.return_value = existing_gnbs config_mount = scenario.Mount( location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_gnb_identity_relation_1 = scenario.Relation( @@ -1140,10 +1430,12 @@ def test_given_gnb_exists_in_nms_and_new_gnb_relation_is_added_when_pebble_ready relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_gnb_identity_relation_1, fiveg_gnb_identity_relation_2, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_joined(fiveg_gnb_identity_relation_2), state_in) @@ -1172,6 +1464,9 @@ def test_given_two_n4_relations_when_n4_relation_broken_then_upf_is_removed_from "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_upfs = [ Upf(hostname="some.host.name", port=1234), Upf(hostname="some.host", port=22), @@ -1181,11 +1476,16 @@ def test_given_two_n4_relations_when_n4_relation_broken_then_upf_is_removed_from location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_n4_relation_1 = scenario.Relation( @@ -1210,10 +1510,12 @@ def test_given_two_n4_relations_when_n4_relation_broken_then_upf_is_removed_from relations={ common_database_relation, auth_database_relation, + certificates_relation, fiveg_n4_relation_1, fiveg_n4_relation_2, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_broken(fiveg_n4_relation_1), state_in) @@ -1242,6 +1544,9 @@ def test_given_two_gnb_identity_relations_when_relation_broken_then_gnb_is_remov "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_gnbs = [ GnodeB(name="some.gnb.name", tac=1234), GnodeB(name="gnb.name", tac=333), @@ -1251,11 +1556,16 @@ def test_given_two_gnb_identity_relations_when_relation_broken_then_gnb_is_remov location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) gnb_identity_relation_1 = scenario.Relation( @@ -1280,10 +1590,12 @@ def test_given_two_gnb_identity_relations_when_relation_broken_then_gnb_is_remov relations={ common_database_relation, auth_database_relation, + certificates_relation, gnb_identity_relation_1, gnb_identity_relation_2, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_broken(gnb_identity_relation_1), state_in) @@ -1312,6 +1624,13 @@ def test_given_one_upf_in_nms_when_upf_is_modified_in_relation_then_nms_upfs_are "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) existing_upfs = [ Upf(hostname="some.host.name", port=1234), ] @@ -1325,6 +1644,7 @@ def test_given_one_upf_in_nms_when_upf_is_modified_in_relation_then_nms_upfs_are can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_n4_relation = scenario.Relation( @@ -1338,8 +1658,14 @@ def test_given_one_upf_in_nms_when_upf_is_modified_in_relation_then_nms_upfs_are state_in = scenario.State( leader=True, containers={container}, - relations={common_database_relation, auth_database_relation, fiveg_n4_relation}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + fiveg_n4_relation, + }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_joined(fiveg_n4_relation), state_in) @@ -1368,6 +1694,9 @@ def test_given_one_gnb_in_nms_when_gnb_is_modified_in_relation_then_nms_gnbs_are "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_gnbs = [ GnodeB(name="some.gnb.name", tac=1234), ] @@ -1376,11 +1705,16 @@ def test_given_one_gnb_in_nms_when_gnb_is_modified_in_relation_then_nms_gnbs_are location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) gnb_identity_relation = scenario.Relation( @@ -1397,9 +1731,11 @@ def test_given_one_gnb_in_nms_when_gnb_is_modified_in_relation_then_nms_gnbs_are relations={ common_database_relation, auth_database_relation, + certificates_relation, gnb_identity_relation, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_joined(gnb_identity_relation), state_in) @@ -1428,6 +1764,9 @@ def test_given_one_upf_in_nms_when_new_upf_is_added_then_old_upf_is_removed_and_ "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_upfs = [ Upf(hostname="old.name", port=1234), ] @@ -1436,11 +1775,16 @@ def test_given_one_upf_in_nms_when_new_upf_is_added_then_old_upf_is_removed_and_ location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) fiveg_n4_relation = scenario.Relation( @@ -1454,8 +1798,14 @@ def test_given_one_upf_in_nms_when_new_upf_is_added_then_old_upf_is_removed_and_ state_in = scenario.State( leader=True, containers={container}, - relations={common_database_relation, auth_database_relation, fiveg_n4_relation}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + fiveg_n4_relation, + }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_joined(fiveg_n4_relation), state_in) @@ -1484,6 +1834,9 @@ def test_given_one_gnb_in_nms_when_new_gnb_is_added_then_old_gnb_is_removed_and_ "uris": "2.2.2.2:1234", }, ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) existing_gnbs = [ GnodeB(name="old.gnb.name", tac=1234), ] @@ -1492,11 +1845,16 @@ def test_given_one_gnb_in_nms_when_new_gnb_is_added_then_old_gnb_is_removed_and_ location="/nms/config", source=tempdir, ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) container = scenario.Container( name="nms", can_connect=True, mounts={ "config": config_mount, + "certs": certs_mount, }, ) gnb_identity_relation = scenario.Relation( @@ -1513,11 +1871,81 @@ def test_given_one_gnb_in_nms_when_new_gnb_is_added_then_old_gnb_is_removed_and_ relations={ common_database_relation, auth_database_relation, + certificates_relation, gnb_identity_relation, }, ) + self.mock_certificate_is_available.return_value = True self.ctx.run(self.ctx.on.relation_joined(gnb_identity_relation), state_in) self.mock_delete_gnb.assert_called_once_with(name="old.gnb.name") self.mock_create_gnb.assert_called_once_with(name="some.gnb.name", tac=6789) + + def test_given_cannot_connect_to_container_when_certificates_relation_broken_then_certificates_are_not_removed( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.1.1.1:1234", + }, + ) + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "2.2.2.2:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + existing_gnbs = [ + GnodeB(name="old.gnb.name", tac=1234), + ] + self.mock_list_gnbs.return_value = existing_gnbs + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=False, + mounts={ + "config": config_mount, + "certs": certs_mount, + }, + ) + gnb_identity_relation = scenario.Relation( + endpoint="fiveg_gnb_identity", + interface="fiveg_gnb_identity", + remote_app_data={ + "gnb_name": "some.gnb.name", + "tac": "6789", + }, + ) + state_in = scenario.State( + leader=True, + containers={container}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + gnb_identity_relation, + }, + ) + self.mock_certificate_is_available.return_value = True + + self.ctx.run(self.ctx.on.relation_broken(certificates_relation), state_in) + diff --git a/tests/unit/test_charm_tls_certificates.py b/tests/unit/test_charm_tls_certificates.py new file mode 100644 index 00000000..99cf1387 --- /dev/null +++ b/tests/unit/test_charm_tls_certificates.py @@ -0,0 +1,317 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +import os +import tempfile + +import scenario +from ops import testing + +from tests.unit.certificates_helpers import example_cert_and_key +from tests.unit.fixtures import NMSTlsCertificatesFixtures + + +class TestCharmTlsCertificates(NMSTlsCertificatesFixtures): + def test_given_certificates_are_stored_when_on_certificates_relation_broken_then_certificates_are_removed( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + certificates_relation = testing.Relation( + endpoint="certificates", interface="tls-certificates" + ) + certs_mount = testing.Mount( + location="/support/TLS", + source=tempdir, + ) + config_mount = testing.Mount( + location="/nms/config", + source=tempdir, + ) + container = testing.Container( + name="nms", + can_connect=True, + mounts={"certs": certs_mount, "config": config_mount}, + ) + os.mkdir(f"{tempdir}/support") + os.mkdir(f"{tempdir}/support/TLS") + with open(f"{tempdir}/nms.pem", "w") as f: + f.write("certificate") + + with open(f"{tempdir}/nms.key", "w") as f: + f.write("private key") + + with open(f"{tempdir}/ca.pem", "w") as f: + f.write("CA certificate") + + state_in = testing.State( + relations=[certificates_relation], + containers=[container], + leader=True, + ) + + self.ctx.run(self.ctx.on.relation_broken(certificates_relation), state_in) + + assert not os.path.exists(f"{tempdir}/nms.pem") + assert not os.path.exists(f"{tempdir}/nms.key") + assert not os.path.exists(f"{tempdir}/ca.pem") + + def test_given_cannot_connect_to_container_when_on_certificates_relation_broken_then_certificates_are_not_removed( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + certificates_relation = testing.Relation( + endpoint="certificates", interface="tls-certificates" + ) + certs_mount = testing.Mount( + location="/support/TLS", + source=tempdir, + ) + config_mount = testing.Mount( + location="/nms/config", + source=tempdir, + ) + container = testing.Container( + name="nms", + can_connect=False, + mounts={"certs": certs_mount, "config": config_mount}, + ) + os.mkdir(f"{tempdir}/support") + os.mkdir(f"{tempdir}/support/TLS") + with open(f"{tempdir}/nms.pem", "w") as f: + f.write("certificate") + + with open(f"{tempdir}/nms.key", "w") as f: + f.write("private key") + + with open(f"{tempdir}/ca.pem", "w") as f: + f.write("CA certificate") + + state_in = testing.State( + relations=[certificates_relation], + containers=[container], + leader=True, + ) + + self.ctx.run(self.ctx.on.relation_broken(certificates_relation), state_in) + + assert os.path.exists(f"{tempdir}/nms.pem") + assert os.path.exists(f"{tempdir}/nms.key") + assert os.path.exists(f"{tempdir}/ca.pem") + + def test_given_certificate_matches_stored_one_when_pebble_ready_then_certificate_is_not_pushed( + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "apple", + "password": "hamburger", + "uris": "1.2.3.4:1234", + }, + ) + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.1.1.1:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={"config": config_mount, "certs": certs_mount}, + ) + state_in = testing.State( + leader=True, + relations=[ + auth_database_relation, + common_database_relation, + certificates_relation, + ], + containers={container}, + ) + provider_certificate, private_key = example_cert_and_key( + relation_id=certificates_relation.id + ) + self.mock_get_assigned_certificate.return_value = (provider_certificate, private_key) + with open(f"{tempdir}/nms.pem", "w") as f: + f.write(str(provider_certificate.certificate)) + with open(f"{tempdir}/nms.key", "w") as f: + f.write(str(private_key)) + with open(f"{tempdir}/ca.pem", "w") as f: + f.write(str(provider_certificate.ca)) + config_modification_time_nms_pem = os.stat(tempdir + "/nms.pem").st_mtime + config_modification_time_nms_key = os.stat(tempdir + "/nms.key").st_mtime + config_modification_time_ca_pem = os.stat(tempdir + "/ca.pem").st_mtime + + self.ctx.run(self.ctx.on.pebble_ready(container=container), state_in) + + assert os.stat(tempdir + "/nms.pem").st_mtime == config_modification_time_nms_pem + assert os.stat(tempdir + "/nms.key").st_mtime == config_modification_time_nms_key + assert os.stat(tempdir + "/ca.pem").st_mtime == config_modification_time_ca_pem + + def test_given_storage_attached_and_certificate_available_when_pebble_ready_then_certs_are_written( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.9.11.4:1234", + }, + ) + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.8.11.4:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={ + "config": config_mount, + "certs": certs_mount, + }, + ) + state_in = scenario.State( + leader=True, + containers={container}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, + ) + provider_certificate, private_key = example_cert_and_key( + relation_id=certificates_relation.id + ) + self.mock_get_assigned_certificate.return_value = (provider_certificate, private_key) + + self.ctx.run(self.ctx.on.pebble_ready(container), state_in) + + with open(tempdir + "/nms.pem", "r") as f: + assert f.read() == str(provider_certificate.certificate) + with open(tempdir + "/nms.key", "r") as f: + assert f.read() == str(private_key) + with open(tempdir + "/ca.pem", "r") as f: + assert f.read() == str(provider_certificate.ca) + + def test_given_certificate_exist_and_are_different_when_pebble_ready_then_certs_are_overwritten( # noqa: E501 + self, + ): + with tempfile.TemporaryDirectory() as tempdir: + common_database_relation = scenario.Relation( + endpoint="common_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.9.11.4:1234", + }, + ) + auth_database_relation = scenario.Relation( + endpoint="auth_database", + interface="mongodb_client", + remote_app_data={ + "username": "banana", + "password": "pizza", + "uris": "1.8.11.4:1234", + }, + ) + certificates_relation = scenario.Relation( + endpoint="certificates", interface="tls-certificates" + ) + config_mount = scenario.Mount( + location="/nms/config", + source=tempdir, + ) + certs_mount = scenario.Mount( + location="/support/TLS", + source=tempdir, + ) + container = scenario.Container( + name="nms", + can_connect=True, + mounts={ + "config": config_mount, + "certs": certs_mount, + }, + ) + os.mkdir(f"{tempdir}/support") + os.mkdir(f"{tempdir}/support/TLS") + old_provider_certificate, old_private_key = example_cert_and_key( + relation_id=auth_database_relation.id + ) + with open(f"{tempdir}/nms.pem", "w") as f: + f.write(str(old_provider_certificate.certificate)) + + with open(f"{tempdir}/nms.key", "w") as f: + f.write(str(old_private_key)) + + with open(f"{tempdir}/ca.pem", "w") as f: + f.write(str(old_provider_certificate.ca)) + + state_in = scenario.State( + leader=True, + containers={container}, + relations={ + common_database_relation, + auth_database_relation, + certificates_relation, + }, + ) + new_provider_certificate, new_private_key = example_cert_and_key( + relation_id=certificates_relation.id + ) + assert new_provider_certificate.certificate != old_provider_certificate.certificate + assert new_provider_certificate.ca != old_provider_certificate.ca + assert new_private_key != old_private_key + + self.mock_get_assigned_certificate.return_value = ( + new_provider_certificate, + new_private_key + ) + + self.ctx.run(self.ctx.on.pebble_ready(container), state_in) + + with open(tempdir + "/nms.pem", "r") as f: + assert f.read() == str(new_provider_certificate.certificate) + with open(tempdir + "/nms.key", "r") as f: + assert f.read() == str(new_private_key) + with open(tempdir + "/ca.pem", "r") as f: + assert f.read() == str(new_provider_certificate.ca) + + diff --git a/tests/unit/test_nms.py b/tests/unit/test_nms.py index 6282e79d..e2e91352 100644 --- a/tests/unit/test_nms.py +++ b/tests/unit/test_nms.py @@ -69,6 +69,7 @@ def test_when_list_gnbs_then_gnb_url_is_used(self): url="some_url/config/v1/inventory/gnb", headers={"Content-Type": "application/json"}, json=None, + verify=False, ) def test_given_nms_returns_a_gnb_list_when_list_gnbs_then_a_gnb_list_is_returned(self): @@ -129,6 +130,7 @@ def test_given_exception_is_raised_when_create_gnb_then_exception_is_handled(sel url="some_url/config/v1/inventory/gnb/some.gnb.name", headers={"Content-Type": "application/json"}, json={"tac": "111"}, + verify=False, ) def test_given_a_valid_gnb_when_create_gnb_then_gnb_is_added_to_nms(self): @@ -139,6 +141,7 @@ def test_given_a_valid_gnb_when_create_gnb_then_gnb_is_added_to_nms(self): url="some_url/config/v1/inventory/gnb/some.gnb.name", headers={"Content-Type": "application/json"}, json={"tac": "111"}, + verify=False, ) @pytest.mark.parametrize( @@ -163,6 +166,7 @@ def test_given_exception_is_raised_when_delete_gnb_then_exceptions_is_handled(se url="some_url/config/v1/inventory/gnb/some.gnb.name", headers={"Content-Type": "application/json"}, json=None, + verify=False, ) def test_given_valid_gnb_when_delete_gnb_then_gnb_is_successfully_deleted(self): @@ -174,6 +178,7 @@ def test_given_valid_gnb_when_delete_gnb_then_gnb_is_successfully_deleted(self): url="some_url/config/v1/inventory/gnb/some.gnb.name", headers={"Content-Type": "application/json"}, json=None, + verify=False, ) @pytest.mark.parametrize( @@ -207,6 +212,7 @@ def test_when_list_upfs_then_upf_url_is_used(self): url="some_url/config/v1/inventory/upf", headers={"Content-Type": "application/json"}, json=None, + verify=False, ) def test_given_nms_returns_a_upf_list_when_list_upfs_then_a_upf_list_is_returned(self): @@ -267,6 +273,7 @@ def test_given_exception_is_raised_when_create_upf_then_exception_is_handled(sel url="some_url/config/v1/inventory/upf/some.upf.name", headers={"Content-Type": "application/json"}, json={"port": "111"}, + verify=False, ) def test_given_a_valid_upf_when_create_upf_then_upf_is_added_to_nms(self): @@ -277,6 +284,7 @@ def test_given_a_valid_upf_when_create_upf_then_upf_is_added_to_nms(self): url="some_url/config/v1/inventory/upf/some.upf.name", headers={"Content-Type": "application/json"}, json={"port": "22"}, + verify=False, ) @pytest.mark.parametrize( @@ -301,6 +309,7 @@ def test_given_exception_is_raised_when_delete_upf_then_exceptions_is_handled(se url="some_url/config/v1/inventory/upf/some.upf.name", headers={"Content-Type": "application/json"}, json=None, + verify=False, ) def test_given_valid_upf_when_delete_upf_then_upf_is_successfully_deleted(self): @@ -313,6 +322,7 @@ def test_given_valid_upf_when_delete_upf_then_upf_is_successfully_deleted(self): url="some_url/config/v1/inventory/upf/some.upf.name", headers={"Content-Type": "application/json"}, json=None, + verify=False, ) @pytest.mark.parametrize( diff --git a/tests/unit/test_tls.py b/tests/unit/test_tls.py new file mode 100644 index 00000000..b46621cc --- /dev/null +++ b/tests/unit/test_tls.py @@ -0,0 +1,95 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +from unittest.mock import MagicMock, patch + +import pytest +from charms.tls_certificates_interface.v4.tls_certificates import Certificate, PrivateKey + +from tests.unit.certificates_helpers import example_cert_and_key +from tls import CA_CERTIFICATE_NAME, CERTIFICATE_NAME, PRIVATE_KEY_NAME, Tls + +STORAGE_PATH = "test" +CA_CERTIFICATE_PATH = f"{STORAGE_PATH}/{CA_CERTIFICATE_NAME}" +CERTIFICATE_PATH = f"{STORAGE_PATH}/{CERTIFICATE_NAME}" +PRIVATE_KEY_PATH = f"{STORAGE_PATH}/{PRIVATE_KEY_NAME}" + + +class TestTls: + patcher_get_assigned_certificate = patch( + "charms.tls_certificates_interface.v4.tls_certificates.TLSCertificatesRequiresV4.get_assigned_certificate" + ) + + @pytest.fixture(autouse=True) + def setUp(self, request): + self.mock_get_assigned_certificate = TestTls.patcher_get_assigned_certificate.start() + mock_charm = MagicMock() + self.mock_container = MagicMock() + self.tls = Tls( + charm=mock_charm, + relation_name="certs", + container=self.mock_container, + domain_name="test", + workload_storage_path=STORAGE_PATH + ) + request.addfinalizer(self.tearDown) + + @staticmethod + def tearDown() -> None: + patch.stopall() + + def test_given_get_assigned_certificate_valid_values_then_certificate_is_available_returns_true(self): # noqa: E501 + mock_cert = MagicMock(spec=Certificate) + mock_key = MagicMock(spec=PrivateKey) + self.mock_get_assigned_certificate.return_value = mock_cert, mock_key + + assert self.tls.certificate_is_available() is True + + @pytest.mark.parametrize( + "certificate, private_key", + [ + (None, None), + (None, MagicMock(spec=PrivateKey)), + (MagicMock(spec=Certificate), None), + ] + ) + def test_given_get_assigned_certificate_returns_none_then_certificate_is_available_returns_false( # noqa: E501 + self, certificate, private_key + ): + self.mock_get_assigned_certificate.return_value = certificate, private_key + + assert self.tls.certificate_is_available() is False + + def test_given_certificates_do_not_exist_when_check_and_update_certificate_then_certificates_are_stored(self): # noqa: E501 + mock_cert, mock_key = example_cert_and_key() + self.mock_get_assigned_certificate.return_value = mock_cert, mock_key + self.mock_container.exists.return_value = False + + was_updated = self.tls.check_and_update_certificate() + + assert was_updated is True + self.mock_container.push.assert_any_call(path=PRIVATE_KEY_PATH, source=str(mock_key)) + self.mock_container.push.assert_any_call( + path=CERTIFICATE_PATH, + source=str(mock_cert.certificate) + ) + self.mock_container.push.assert_any_call( + path=CA_CERTIFICATE_PATH, + source=str(mock_cert.ca) + ) + + def test_given_certificate_private_key_and_ca_certificate_exist_when_clean_up_certificates_then_certificates_are_removed(self): # noqa: E501 + self.mock_container.exists.return_value = True + + self.tls.clean_up_certificates() + + self.mock_container.remove_path.assert_any_call(path=CERTIFICATE_PATH) + self.mock_container.remove_path.assert_any_call(path=PRIVATE_KEY_PATH) + self.mock_container.remove_path.assert_any_call(path=CA_CERTIFICATE_PATH) + + def test_given_certificate_private_key_and_ca_certificate_do_not_exist_when_clean_up_certificates_then_certificates_are_removed(self): # noqa: E501 + self.mock_container.exists.return_value = False + + self.tls.clean_up_certificates() + + self.mock_container.remove_path.assert_not_called()