From 00f61db1b85748a24e1ac4317dcca90bd8edf175 Mon Sep 17 00:00:00 2001 From: Michael Dmitry <33381599+michaeldmitry@users.noreply.github.com> Date: Fri, 11 Oct 2024 13:27:45 +0300 Subject: [PATCH] Refresh CSR "whenever" needed (#110) * refresh csr if needed * use stable hash * static * sort sans * lint * comments --- .../observability_libs/v1/cert_handler.py | 42 ++++++---- .../test_cert_handler/test_cert_handler_v1.py | 76 +++++++++++-------- 2 files changed, 72 insertions(+), 46 deletions(-) diff --git a/lib/charms/observability_libs/v1/cert_handler.py b/lib/charms/observability_libs/v1/cert_handler.py index 4a1940b..26be879 100644 --- a/lib/charms/observability_libs/v1/cert_handler.py +++ b/lib/charms/observability_libs/v1/cert_handler.py @@ -32,6 +32,7 @@ Since this library uses [Juju Secrets](https://juju.is/docs/juju/secret) it requires Juju >= 3.0.3. """ import abc +import hashlib import ipaddress import json import socket @@ -67,7 +68,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 1 -LIBPATCH = 13 +LIBPATCH = 14 VAULT_SECRET_LABEL = "cert-handler-private-vault" @@ -301,14 +302,11 @@ def __init__( Must match metadata.yaml. cert_subject: Custom subject. Name collisions are under the caller's responsibility. sans: DNS names. If none are given, use FQDN. - refresh_events: an optional list of bound events which - will be observed to replace the current CSR with a new one - if there are changes in the CSR's DNS SANs or IP SANs. - Then, subsequently, replace its corresponding certificate with a new one. + refresh_events: [DEPRECATED]. """ super().__init__(charm, key) # use StoredState to store the hash of the CSR - # to potentially trigger a CSR renewal on `refresh_events` + # to potentially trigger a CSR renewal self._stored.set_default( csr_hash=None, ) @@ -320,8 +318,9 @@ def __init__( # Use fqdn only if no SANs were given, and drop empty/duplicate SANs sans = list(set(filter(None, (sans or [socket.getfqdn()])))) - self.sans_ip = list(filter(is_ip_address, sans)) - self.sans_dns = list(filterfalse(is_ip_address, sans)) + # sort SANS lists to avoid unnecessary csr renewals during reconciliation + self.sans_ip = sorted(filter(is_ip_address, sans)) + self.sans_dns = sorted(filterfalse(is_ip_address, sans)) if self._check_juju_supports_secrets(): vault_backend = _SecretVaultBackend(charm, secret_label=VAULT_SECRET_LABEL) @@ -367,13 +366,15 @@ def __init__( ) if refresh_events: - for ev in refresh_events: - self.framework.observe(ev, self._on_refresh_event) + logger.warn( + "DEPRECATION WARNING. `refresh_events` is now deprecated. CertHandler will automatically refresh the CSR when necessary." + ) - def _on_refresh_event(self, _): - """Replace the latest current CSR with a new one if there are any SANs changes.""" - if self._stored.csr_hash != self._csr_hash: - self._generate_csr(renew=True) + self._reconcile() + + def _reconcile(self): + """Run all logic that is independent of what event we're processing.""" + self._refresh_csr_if_needed() def _on_upgrade_charm(self, _): has_privkey = self.vault.get_value("private-key") @@ -388,6 +389,11 @@ def _on_upgrade_charm(self, _): # this will call `self.private_key` which will generate a new privkey. self._generate_csr(renew=True) + def _refresh_csr_if_needed(self): + """Refresh the current CSR with a new one if there are any SANs changes.""" + if self._stored.csr_hash is not None and self._stored.csr_hash != self._csr_hash: + self._generate_csr(renew=True) + def _migrate_vault(self): peer_backend = _RelationVaultBackend(self.charm, relation_name="peers") @@ -440,13 +446,17 @@ def enabled(self) -> bool: return True @property - def _csr_hash(self) -> int: + def _csr_hash(self) -> str: """A hash of the config that constructs the CSR. Only include here the config options that, should they change, should trigger a renewal of the CSR. """ - return hash( + + def _stable_hash(data): + return hashlib.sha256(str(data).encode()).hexdigest() + + return _stable_hash( ( tuple(self.sans_dns), tuple(self.sans_ip), diff --git a/tests/scenario/test_cert_handler/test_cert_handler_v1.py b/tests/scenario/test_cert_handler/test_cert_handler_v1.py index 78defb8..212cfa4 100644 --- a/tests/scenario/test_cert_handler/test_cert_handler_v1.py +++ b/tests/scenario/test_cert_handler/test_cert_handler_v1.py @@ -4,7 +4,7 @@ import sys from contextlib import contextmanager from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from cryptography import x509 @@ -36,7 +36,7 @@ def __init__(self, fw): if hostname := self._mock_san: sans.append(hostname) - self.ch = CertHandler(self, key="ch", sans=sans, refresh_events=[self.on.config_changed]) + self.ch = CertHandler(self, key="ch", sans=sans) @property def _mock_san(self): @@ -145,6 +145,14 @@ def _cert_renew_patch(): yield patcher +@contextmanager +def _cert_generate_patch(): + with patch( + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation" + ) as patcher: + yield patcher + + @pytest.mark.parametrize("leader", (True, False)) def test_cert_joins(ctx, certificates, leader): with ctx.manager( @@ -183,48 +191,56 @@ def test_cert_joins_peer_vault_backend(ctx_juju2, certificates, leader): assert mgr.charm.ch.private_key -def test_renew_csr_on_sans_change(ctx, certificates): - # generate a CSR +# CertHandler generates a cert on `config_changed` event +@pytest.mark.parametrize( + "event,expected_generate_calls", + (("update_status", 0), ("start", 0), ("install", 0), ("config_changed", 1)), +) +def test_no_renew_if_no_initial_csr_was_generated( + event, expected_generate_calls, ctx, certificates +): + with _cert_renew_patch() as renew_patch: + with _cert_generate_patch() as generate_patch: + with ctx.manager( + event, + State(leader=True, relations=[certificates]), + ) as mgr: + + mgr.run() + assert renew_patch.call_count == 0 + assert generate_patch.call_count == expected_generate_calls + + +@patch.object(CertHandler, "_stored", MagicMock()) +@pytest.mark.parametrize( + "is_relation, event", + ( + (False, "start"), + (True, "changed_event"), + (False, "config_changed"), + ), +) +def test_csr_renew_on_any_event(is_relation, event, ctx, certificates): with ctx.manager( - certificates.joined_event, - State(leader=True, relations=[certificates]), + getattr(certificates, event) if is_relation else event, + State( + leader=True, + relations=[certificates], + ), ) as mgr: charm = mgr.charm state_out = mgr.run() orig_csr = get_csr_obj(charm.ch._csr) assert get_sans_from_csr(orig_csr) == {socket.getfqdn()} - # trigger a config_changed with a modified SAN with _sans_patch(): - with ctx.manager("config_changed", state_out) as mgr: + with ctx.manager("update_status", state_out) as mgr: charm = mgr.charm state_out = mgr.run() csr = get_csr_obj(charm.ch._csr) - # assert CSR contains updated SAN assert get_sans_from_csr(csr) == {socket.getfqdn(), MOCK_HOSTNAME} -def test_csr_no_change_on_wrong_refresh_event(ctx, certificates): - with _cert_renew_patch() as renew_patch: - with ctx.manager( - "config_changed", - State(leader=True, relations=[certificates]), - ) as mgr: - charm = mgr.charm - state_out = mgr.run() - orig_csr = get_csr_obj(charm.ch._csr) - assert get_sans_from_csr(orig_csr) == {socket.getfqdn()} - - with _sans_patch(): - with _cert_renew_patch() as renew_patch: - with ctx.manager("update_status", state_out) as mgr: - charm = mgr.charm - state_out = mgr.run() - csr = get_csr_obj(charm.ch._csr) - assert get_sans_from_csr(csr) == {socket.getfqdn()} - assert renew_patch.call_count == 0 - - def test_csr_no_change(ctx, certificates): with ctx.manager(