From 5994a027c1f5aafc1c92a0f528209890c96ebac1 Mon Sep 17 00:00:00 2001 From: Github Actions Date: Thu, 1 Feb 2024 16:04:26 +0000 Subject: [PATCH] chore: update charm libraries --- .../v0/certificate_transfer.py | 4 +- lib/charms/loki_k8s/v1/loki_push_api.py | 270 +++++++++- .../prometheus_k8s/v0/prometheus_scrape.py | 28 +- .../v1/prometheus_remote_write.py | 11 +- .../v2/tls_certificates.py | 506 +++++++++++++----- 5 files changed, 651 insertions(+), 168 deletions(-) diff --git a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py index edad1d9..44ddfda 100644 --- a/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py +++ b/lib/charms/certificate_transfer_interface/v0/certificate_transfer.py @@ -97,7 +97,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): import logging from typing import List -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent from ops.framework import EventBase, EventSource, Handle, Object @@ -109,7 +109,7 @@ def _on_certificate_removed(self, event: CertificateRemovedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 4 +LIBPATCH = 5 PYDEPS = ["jsonschema"] diff --git a/lib/charms/loki_k8s/v1/loki_push_api.py b/lib/charms/loki_k8s/v1/loki_push_api.py index bcc3e6a..2a2e585 100644 --- a/lib/charms/loki_k8s/v1/loki_push_api.py +++ b/lib/charms/loki_k8s/v1/loki_push_api.py @@ -20,6 +20,10 @@ send telemetry, such as logs, to Loki through a Log Proxy by implementing the consumer side of the `loki_push_api` relation interface. +- `LogForwarder`: This object can be used by any Charmed Operator which needs to send the workload +standard output (stdout) through Pebble's log forwarding mechanism, to Loki endpoints through the +`loki_push_api` relation interface. + Filtering logs in Loki is largely performed on the basis of labels. In the Juju ecosystem, Juju topology labels are used to uniquely identify the workload which generates telemetry like logs. @@ -57,7 +61,7 @@ Subsequently, a Loki charm may instantiate the `LokiPushApiProvider` in its constructor as follows: - from charms.loki_k8s.v0.loki_push_api import LokiPushApiProvider + from charms.loki_k8s.v1.loki_push_api import LokiPushApiProvider from loki_server import LokiServer ... @@ -163,7 +167,7 @@ def __init__(self, *args): sends logs). ```python -from charms.loki_k8s.v0.loki_push_api import LokiPushApiConsumer +from charms.loki_k8s.v1.loki_push_api import LokiPushApiConsumer class LokiClientCharm(CharmBase): @@ -349,6 +353,45 @@ def _promtail_error(self, event): ) ``` +## LogForwarder class Usage + +Let's say that we have a charm's workload that writes logs to the standard output (stdout), +and we need to send those logs to a workload implementing the `loki_push_api` interface, +such as `Loki` or `Grafana Agent`. To know how to reach a Loki instance, a charm would +typically use the `loki_push_api` interface. + +Use the `LogForwarder` class by instantiating it in the `__init__` method of the charm: + +```python +from charms.loki_k8s.v1.loki_push_api import LogForwarder + +... + + def __init__(self, *args): + ... + self._log_forwarder = LogForwarder( + self, + relation_name="logging" # optional, defaults to `logging` + ) +``` + +The `LogForwarder` by default will observe relation events on the `logging` endpoint and +enable/disable log forwarding automatically. +Next, modify the `metadata.yaml` file to add: + +The `log-forwarding` relation in the `requires` section: +```yaml +requires: + logging: + interface: loki_push_api + optional: true +``` + +Once the LogForwader class is implemented in your charm and the relation (implementing the +`loki_push_api` interface) is active and healthy, the library will inject a Pebble layer in +each workload container the charm has access to, to configure Pebble's log forwarding +feature and start sending logs to Loki. + ## Alerting Rules This charm library also supports gathering alerting rules from all related Loki client @@ -463,6 +506,7 @@ def _alert_rules_error(self, event): WorkloadEvent, ) from ops.framework import EventBase, EventSource, Object, ObjectEvents +from ops.jujuversion import JujuVersion from ops.model import Container, ModelError, Relation from ops.pebble import APIError, ChangeError, Layer, PathError, ProtocolError @@ -474,7 +518,7 @@ def _alert_rules_error(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 1 +LIBPATCH = 3 logger = logging.getLogger(__name__) @@ -2045,7 +2089,21 @@ def _download_and_push_promtail_to_workload( - "binsha": sha256 sum of unpacked promtail binary container: container into which promtail is to be uploaded. """ - with request.urlopen(promtail_info["url"]) as r: + # Check for Juju proxy variables and fall back to standard ones if not set + proxies: Optional[Dict[str, str]] = {} + if proxies and os.environ.get("JUJU_CHARM_HTTP_PROXY"): + proxies.update({"http": os.environ["JUJU_CHARM_HTTP_PROXY"]}) + if proxies and os.environ.get("JUJU_CHARM_HTTPS_PROXY"): + proxies.update({"https": os.environ["JUJU_CHARM_HTTPS_PROXY"]}) + if proxies and os.environ.get("JUJU_CHARM_NO_PROXY"): + proxies.update({"no_proxy": os.environ["JUJU_CHARM_NO_PROXY"]}) + else: + proxies = None + + proxy_handler = request.ProxyHandler(proxies) + opener = request.build_opener(proxy_handler) + + with opener.open(promtail_info["url"]) as r: file_bytes = r.read() file_path = os.path.join(BINARY_DIR, promtail_info["filename"] + ".gz") with open(file_path, "wb") as f: @@ -2313,6 +2371,210 @@ def _containers(self) -> Dict[str, Container]: return {cont: self._charm.unit.get_container(cont) for cont in self._logs_scheme.keys()} +class _PebbleLogClient: + @staticmethod + def check_juju_version() -> bool: + """Make sure the Juju version supports Log Forwarding.""" + juju_version = JujuVersion.from_environ() + if not juju_version > JujuVersion(version=str("3.3")): + msg = f"Juju version {juju_version} does not support Pebble log forwarding. Juju >= 3.4 is needed." + logger.warning(msg) + return False + return True + + @staticmethod + def _build_log_target( + unit_name: str, loki_endpoint: str, topology: JujuTopology, enable: bool + ) -> Dict: + """Build a log target for the log forwarding Pebble layer. + + Log target's syntax for enabling/disabling forwarding is explained here: + https://github.com/canonical/pebble?tab=readme-ov-file#log-forwarding + """ + services_value = ["all"] if enable else ["-all"] + + log_target = { + "override": "replace", + "services": services_value, + "type": "loki", + "location": loki_endpoint, + } + if enable: + log_target.update( + { + "labels": { + "product": "Juju", + "charm": topology._charm_name, + "juju_model": topology._model, + "juju_model_uuid": topology._model_uuid, + "juju_application": topology._application, + "juju_unit": topology._unit, + }, + } + ) + + return {unit_name: log_target} + + @staticmethod + def _build_log_targets( + loki_endpoints: Optional[Dict[str, str]], topology: JujuTopology, enable: bool + ): + """Build all the targets for the log forwarding Pebble layer.""" + targets = {} + if not loki_endpoints: + return targets + + for unit_name, endpoint in loki_endpoints.items(): + targets.update( + _PebbleLogClient._build_log_target( + unit_name=unit_name, + loki_endpoint=endpoint, + topology=topology, + enable=enable, + ) + ) + return targets + + @staticmethod + def disable_inactive_endpoints( + container: Container, active_endpoints: Dict[str, str], topology: JujuTopology + ): + """Disable forwarding for inactive endpoints by checking against the Pebble plan.""" + pebble_layer = container.get_plan().to_dict().get("log-targets", None) + if not pebble_layer: + return + + for unit_name, target in pebble_layer.items(): + # If the layer is a disabled log forwarding endpoint, skip it + if "-all" in target["services"]: # pyright: ignore + continue + + if unit_name not in active_endpoints: + layer = Layer( + { # pyright: ignore + "log-targets": _PebbleLogClient._build_log_targets( + loki_endpoints={unit_name: "(removed)"}, + topology=topology, + enable=False, + ) + } + ) + container.add_layer(f"{container.name}-log-forwarding", layer=layer, combine=True) + + @staticmethod + def enable_endpoints( + container: Container, active_endpoints: Dict[str, str], topology: JujuTopology + ): + """Enable forwarding for the specified Loki endpoints.""" + layer = Layer( + { # pyright: ignore + "log-targets": _PebbleLogClient._build_log_targets( + loki_endpoints=active_endpoints, + topology=topology, + enable=True, + ) + } + ) + container.add_layer(f"{container.name}-log-forwarding", layer, combine=True) + + +class LogForwarder(ConsumerBase): + """Forward the standard outputs of all workloads operated by a charm to one or multiple Loki endpoints.""" + + def __init__( + self, + charm: CharmBase, + *, + relation_name: str = DEFAULT_RELATION_NAME, + alert_rules_path: str = DEFAULT_ALERT_RULES_RELATIVE_PATH, + recursive: bool = True, + skip_alert_topology_labeling: bool = False, + ): + _PebbleLogClient.check_juju_version() + super().__init__( + charm, relation_name, alert_rules_path, recursive, skip_alert_topology_labeling + ) + self._charm = charm + self._relation_name = relation_name + + on = self._charm.on[self._relation_name] + self.framework.observe(on.relation_joined, self._update_logging) + self.framework.observe(on.relation_changed, self._update_logging) + self.framework.observe(on.relation_departed, self._update_logging) + self.framework.observe(on.relation_broken, self._update_logging) + + def _update_logging(self, _): + """Update the log forwarding to match the active Loki endpoints.""" + loki_endpoints = {} + + # Get the endpoints from relation data + for relation in self._charm.model.relations[self._relation_name]: + loki_endpoints.update(self._fetch_endpoints(relation)) + + if not loki_endpoints: + logger.warning("No Loki endpoints available") + return + + for container in self._charm.unit.containers.values(): + _PebbleLogClient.disable_inactive_endpoints( + container=container, + active_endpoints=loki_endpoints, + topology=self.topology, + ) + _PebbleLogClient.enable_endpoints( + container=container, active_endpoints=loki_endpoints, topology=self.topology + ) + + def is_ready(self, relation: Optional[Relation] = None): + """Check if the relation is active and healthy.""" + if not relation: + relations = self._charm.model.relations[self._relation_name] + if not relations: + return False + return all(self.is_ready(relation) for relation in relations) + + try: + if self._extract_urls(relation): + return True + return False + except (KeyError, json.JSONDecodeError): + return False + + def _extract_urls(self, relation: Relation) -> Dict[str, str]: + """Default getter function to extract Loki endpoints from a relation. + + Returns: + A dictionary of remote units and the respective Loki endpoint. + { + "loki/0": "http://loki:3100/loki/api/v1/push", + "another-loki/0": "http://another-loki:3100/loki/api/v1/push", + } + """ + endpoints: Dict = {} + + for unit in relation.units: + endpoint = relation.data[unit]["endpoint"] + deserialized_endpoint = json.loads(endpoint) + url = deserialized_endpoint["url"] + endpoints[unit.name] = url + + return endpoints + + def _fetch_endpoints(self, relation: Relation) -> Dict[str, str]: + """Fetch Loki Push API endpoints from relation data using the endpoints getter.""" + endpoints: Dict = {} + + if not self.is_ready(relation): + logger.warning(f"The relation '{relation.name}' is not ready yet.") + return endpoints + + # if the code gets here, the function won't raise anymore because it's + # also called in is_ready() + endpoints = self._extract_urls(relation) + + return endpoints + + class CosTool: """Uses cos-tool to inject label matchers into alert rule expressions and validate rules.""" diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index e4297aa..665af88 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 42 +LIBPATCH = 44 PYDEPS = ["cosl"] @@ -386,6 +386,7 @@ def _on_scrape_targets_changed(self, event): "basic_auth", "tls_config", "authorization", + "params", } DEFAULT_JOB = { "metrics_path": "/metrics", @@ -764,7 +765,7 @@ def _validate_relation_by_interface_and_direction( actual_relation_interface = relation.interface_name if actual_relation_interface != expected_relation_interface: raise RelationInterfaceMismatchError( - relation_name, expected_relation_interface, actual_relation_interface + relation_name, expected_relation_interface, actual_relation_interface or "None" ) if expected_relation_role == RelationRole.provides: @@ -857,7 +858,7 @@ class MonitoringEvents(ObjectEvents): class MetricsEndpointConsumer(Object): """A Prometheus based Monitoring service.""" - on = MonitoringEvents() + on = MonitoringEvents() # pyright: ignore def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME): """A Prometheus based Monitoring service. @@ -1014,7 +1015,6 @@ def alerts(self) -> dict: try: scrape_metadata = json.loads(relation.data[relation.app]["scrape_metadata"]) identifier = JujuTopology.from_dict(scrape_metadata).identifier - alerts[identifier] = self._tool.apply_label_matchers(alert_rules) # type: ignore except KeyError as e: logger.debug( @@ -1029,6 +1029,10 @@ def alerts(self) -> dict: ) continue + # We need to append the relation info to the identifier. This is to allow for cases for there are two + # relations which eventually scrape the same application. Issue #551. + identifier = f"{identifier}_{relation.name}_{relation.id}" + alerts[identifier] = alert_rules _, errmsg = self._tool.validate_alert_rules(alert_rules) @@ -1294,7 +1298,7 @@ def _resolve_dir_against_charm_path(charm: CharmBase, *path_elements: str) -> st class MetricsEndpointProvider(Object): """A metrics endpoint for Prometheus.""" - on = MetricsEndpointProviderEvents() + on = MetricsEndpointProviderEvents() # pyright: ignore def __init__( self, @@ -1836,14 +1840,16 @@ def _set_prometheus_data(self, event): return jobs = [] + _type_convert_stored( - self._stored.jobs + self._stored.jobs # pyright: ignore ) # list of scrape jobs, one per relation for relation in self.model.relations[self._target_relation]: targets = self._get_targets(relation) if targets and relation.app: jobs.append(self._static_scrape_job(targets, relation.app.name)) - groups = [] + _type_convert_stored(self._stored.alert_rules) # list of alert rule groups + groups = [] + _type_convert_stored( + self._stored.alert_rules # pyright: ignore + ) # list of alert rule groups for relation in self.model.relations[self._alert_rules_relation]: unit_rules = self._get_alert_rules(relation) if unit_rules and relation.app: @@ -1895,7 +1901,7 @@ def set_target_job_data(self, targets: dict, app_name: str, **kwargs) -> None: jobs.append(updated_job) relation.data[self._charm.app]["scrape_jobs"] = json.dumps(jobs) - if not _type_convert_stored(self._stored.jobs) == jobs: + if not _type_convert_stored(self._stored.jobs) == jobs: # pyright: ignore self._stored.jobs = jobs def _on_prometheus_targets_departed(self, event): @@ -1947,7 +1953,7 @@ def remove_prometheus_jobs(self, job_name: str, unit_name: Optional[str] = ""): relation.data[self._charm.app]["scrape_jobs"] = json.dumps(jobs) - if not _type_convert_stored(self._stored.jobs) == jobs: + if not _type_convert_stored(self._stored.jobs) == jobs: # pyright: ignore self._stored.jobs = jobs def _job_name(self, appname) -> str: @@ -2126,7 +2132,7 @@ def set_alert_rule_data(self, name: str, unit_rules: dict, label_rules: bool = T groups.append(updated_group) relation.data[self._charm.app]["alert_rules"] = json.dumps({"groups": groups}) - if not _type_convert_stored(self._stored.alert_rules) == groups: + if not _type_convert_stored(self._stored.alert_rules) == groups: # pyright: ignore self._stored.alert_rules = groups def _on_alert_rules_departed(self, event): @@ -2176,7 +2182,7 @@ def remove_alert_rules(self, group_name: str, unit_name: str) -> None: json.dumps({"groups": groups}) if groups else "{}" ) - if not _type_convert_stored(self._stored.alert_rules) == groups: + if not _type_convert_stored(self._stored.alert_rules) == groups: # pyright: ignore self._stored.alert_rules = groups def _get_alert_rules(self, relation) -> dict: diff --git a/lib/charms/prometheus_k8s/v1/prometheus_remote_write.py b/lib/charms/prometheus_k8s/v1/prometheus_remote_write.py index 8965816..cf24b9f 100644 --- a/lib/charms/prometheus_k8s/v1/prometheus_remote_write.py +++ b/lib/charms/prometheus_k8s/v1/prometheus_remote_write.py @@ -46,7 +46,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 4 PYDEPS = ["cosl"] @@ -211,7 +211,7 @@ def _validate_relation_by_interface_and_direction( actual_relation_interface = relation.interface_name if actual_relation_interface != expected_relation_interface: raise RelationInterfaceMismatchError( - relation_name, expected_relation_interface, actual_relation_interface + relation_name, expected_relation_interface, actual_relation_interface or "None" ) if expected_relation_role == RelationRole.provides: @@ -394,7 +394,7 @@ def __init__(self, *args): ``` """ - on = PrometheusRemoteWriteConsumerEvents() + on = PrometheusRemoteWriteConsumerEvents() # pyright: ignore def __init__( self, @@ -458,7 +458,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.endpoints_changed.emit(relation_id=event.relation.id) def _handle_endpoints_changed(self, event: RelationEvent) -> None: - if self._charm.unit.is_leader(): + if self._charm.unit.is_leader() and event.app is not None: ev = json.loads(event.relation.data[event.app].get("event", "{}")) if ev: @@ -591,7 +591,7 @@ def __init__(self, *args): name to differentiate between "incoming" and "outgoing" remote write interactions is necessary. """ - on = PrometheusRemoteWriteProviderEvents() + on = PrometheusRemoteWriteProviderEvents() # pyright: ignore def __init__( self, @@ -749,6 +749,7 @@ def alerts(self) -> dict: _, errmsg = self._tool.validate_alert_rules(alert_rules) if errmsg: + logger.error(f"Invalid alert rule file: {errmsg}") if self._charm.unit.is_leader(): data = json.loads(relation.data[self._charm.app].get("event", "{}")) data["errors"] = errmsg diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index f4a0836..5992faf 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -286,8 +286,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 -from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import ( CharmBase, CharmEvents, @@ -298,7 +297,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import Relation, SecretNotFoundError +from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" @@ -308,13 +307,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 16 +LIBPATCH = 24 PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/requirer.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -335,7 +334,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven "type": "array", "items": { "type": "object", - "properties": {"certificate_signing_request": {"type": "string"}}, + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, "required": ["certificate_signing_request"], }, } @@ -346,7 +348,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/provider.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -536,22 +538,31 @@ def restore(self, snapshot: dict): class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): + def __init__( + self, + handle: Handle, + certificate_signing_request: str, + relation_id: int, + is_ca: bool = False, + ): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id + self.is_ca = is_ca def snapshot(self) -> dict: """Returns snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, + "is_ca": self.is_ca, } def restore(self, snapshot: dict): """Restores snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] + self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -588,26 +599,63 @@ def restore(self, snapshot: dict): self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: """Loads relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time + ) + + +def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: + """Extract expiry time from a certificate string. + + Args: + certificate (str): x509 certificate as a string + + Returns: + Optional[datetime]: Expiry datetime or None + """ + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + return certificate_object.not_valid_after + except ValueError: + logger.warning("Could not load certificate.") + return None + + def generate_ca( private_key: bytes, subject: str, @@ -678,6 +726,105 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) +def get_certificate_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + alt_names: Optional[List[str]], + is_ca: bool, +) -> List[x509.Extension]: + """Generates a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + + sans: List[x509.GeneralName] = [] + san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] + sans.extend(san_alt_names) + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + def generate_certificate( csr: bytes, ca: bytes, @@ -685,6 +832,7 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, + is_ca: bool = False, ) -> bytes: """Generates a TLS certificate based on a CSR. @@ -695,6 +843,7 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate @@ -713,52 +862,24 @@ def generate_certificate( .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=validity)) - .add_extension( - x509.AuthorityKeyIdentifier( - key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ) - .add_extension( - x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False - ) - .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False) ) - - extensions_list = csr_object.extensions - san_ext: Optional[x509.Extension] = None - if alt_names: - full_sans_dns = alt_names.copy() + extensions = get_certificate_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + alt_names=alt_names, + is_ca=is_ca, + ) + for extension in extensions: try: - loaded_san_ext = csr_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, ) - full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) - except ExtensionNotFound: - pass - finally: - san_ext = Extension( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - False, - x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), - ) - if not extensions_list: - extensions_list = x509.Extensions([san_ext]) - - for extension in extensions_list: - if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: - extension = san_ext - - certificate_builder = certificate_builder.add_extension( - extension.value, - critical=extension.critical, - ) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) - certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) @@ -817,9 +938,11 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption() + ), ) return key_bytes @@ -896,6 +1019,38 @@ def generate_csr( return signed_certificate.public_bytes(serialization.Encoding.PEM) +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + Args: + csr (str): Certificate Signing Request as a string + cert (str): Certificate as a string + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + try: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1171,15 +1326,19 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: certificate_creation_request["certificate_signing_request"] for certificate_creation_request in provider_certificates ] - requirer_unit_csrs = [ - certificate_creation_request["certificate_signing_request"] + requirer_unit_certificate_requests = [ + { + "csr": certificate_creation_request["certificate_signing_request"], + "is_ca": certificate_creation_request.get("ca", False), + } for certificate_creation_request in requirer_csrs ] - for certificate_signing_request in requirer_unit_csrs: - if certificate_signing_request not in provider_csrs: + for certificate_request in requirer_unit_certificate_requests: + if certificate_request["csr"] not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_signing_request, + certificate_signing_request=certificate_request["csr"], relation_id=event.relation.id, + is_ca=certificate_request["is_ca"], ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) @@ -1217,12 +1376,24 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) self.remove_certificate(certificate=certificate["certificate"]) - def get_requirer_csrs_with_no_certs( + def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Filters the requirer's units csrs. + """Returns CSR's for which no certificate has been issued. - Keeps the ones for which no certificate was provided. + Example return: [ + { + "relation_id": 0, + "application_name": "tls-certificates-requirer", + "unit_name": "tls-certificates-requirer/0", + "unit_csrs": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "is_ca": false + } + ] + } + ] Args: relation_id (int): Relation id @@ -1239,6 +1410,7 @@ def get_requirer_csrs_with_no_certs( if not self.certificate_issued_for_csr( app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] csr=csr["certificate_signing_request"], # type: ignore[index] + relation_id=relation_id, ): csrs_without_certs.append(csr) if csrs_without_certs: @@ -1285,17 +1457,21 @@ def get_requirer_csrs( ) return unit_csr_mappings - def certificate_issued_for_csr(self, app_name: str, csr: str) -> bool: + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: """Checks whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. - + relation_id (Optional[int]): Relation ID Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates()[app_name] + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ + app_name + ] for issued_pair in issued_certificates_per_csr: if "csr" in issued_pair and issued_pair["csr"] == csr: return csr_matches_certificate(csr, issued_pair["certificate"]) @@ -1337,8 +1513,17 @@ def __init__( self.framework.observe(charm.on.update_status, self._on_update_status) @property - def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer's CSRs from relation data.""" + def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + """Returns list of requirer's CSRs from relation unit data. + + Example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") @@ -1361,11 +1546,12 @@ def _provider_certificates(self) -> List[Dict[str, str]]: return [] return provider_relation_data.get("certificates", []) - def _add_requirer_csr(self, csr: str) -> None: + def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: """Adds CSR to relation data. Args: csr (str): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1376,7 +1562,10 @@ def _add_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict = {"certificate_signing_request": csr} + new_csr_dict: Dict[str, Union[bool, str]] = { + "certificate_signing_request": csr, + "ca": is_ca, + } if new_csr_dict in self._requirer_csrs: logger.info("CSR already in relation data - Doing nothing") return @@ -1400,18 +1589,22 @@ def _remove_requirer_csr(self, csr: str) -> None: f"The certificate request can't be completed" ) requirer_csrs = copy.deepcopy(self._requirer_csrs) - csr_dict = {"certificate_signing_request": csr} - if csr_dict not in requirer_csrs: - logger.info("CSR not in relation data - Doing nothing") + if not requirer_csrs: + logger.info("No CSRs in relation data - Doing nothing") return - requirer_csrs.remove(csr_dict) + for requirer_csr in requirer_csrs: + if requirer_csr["certificate_signing_request"] == csr: + requirer_csrs.remove(requirer_csr) relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def request_certificate_creation(self, certificate_signing_request: bytes) -> None: + def request_certificate_creation( + self, certificate_signing_request: bytes, is_ca: bool = False + ) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1422,7 +1615,7 @@ def request_certificate_creation(self, certificate_signing_request: bytes) -> No f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip()) + self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1466,6 +1659,92 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") + def get_assigned_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + final_list.append(cert) + return final_list + + def get_expiring_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit that are expiring or expired. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + expiry_time = _get_certificate_expiry_time(cert["certificate"]) + if not expiry_time: + continue + expiry_notification_time = expiry_time - timedelta( + hours=self.expiry_notification_time + ) + if datetime.utcnow() > expiry_notification_time: + final_list.append(cert) + return final_list + + def get_certificate_signing_requests( + self, + fulfilled_only: bool = False, + unfulfilled_only: bool = False, + ) -> List[Dict[str, Union[bool, str]]]: + """Gets the list of CSR's that were sent to the provider. + + You can choose to get only the CSR's that have a certificate assigned or only the CSR's + that don't. + + Args: + fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. + unfulfilled_only (bool): This option will discard CSRs that have certificates signed. + + Returns: + List of CSR dictionaries. For example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ + final_list = [] + for csr in self._requirer_csrs: + assert isinstance(csr["certificate_signing_request"], str) + cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + if (unfulfilled_only and cert) or (fulfilled_only and not cert): + continue + final_list.append(csr) + + return final_list + @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: """Checks whether relation data is valid based on json schema. @@ -1676,68 +1955,3 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: certificate=certificate_dict["certificate"], expiry=expiry_time.isoformat(), ) - - -def csr_matches_certificate(csr: str, cert: str) -> bool: - """Check if a CSR matches a certificate. - - expects to get the original string representations. - - Args: - csr (str): Certificate Signing Request - cert (str): Certificate - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if csr_object.subject != cert_object.subject: - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time - ) - - -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. - - Args: - certificate (str): x509 certificate as a string - - Returns: - Optional[datetime]: Expiry datetime or None - """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after - except ValueError: - logger.warning("Could not load certificate.") - return None