Skip to content

Commit

Permalink
Refresh CSR "whenever" needed (#110)
Browse files Browse the repository at this point in the history
* refresh csr if needed

* use stable hash

* static

* sort sans

* lint

* comments
  • Loading branch information
michaeldmitry authored Oct 11, 2024
1 parent b12d143 commit 00f61db
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 46 deletions.
42 changes: 26 additions & 16 deletions lib/charms/observability_libs/v1/cert_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Since this library uses [Juju Secrets](https://juju.is/docs/juju/secret) it requires Juju >= 3.0.3.
"""
import abc
import hashlib
import ipaddress
import json
import socket
Expand Down Expand Up @@ -67,7 +68,7 @@

LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a"
LIBAPI = 1
LIBPATCH = 13
LIBPATCH = 14

VAULT_SECRET_LABEL = "cert-handler-private-vault"

Expand Down Expand Up @@ -301,14 +302,11 @@ def __init__(
Must match metadata.yaml.
cert_subject: Custom subject. Name collisions are under the caller's responsibility.
sans: DNS names. If none are given, use FQDN.
refresh_events: an optional list of bound events which
will be observed to replace the current CSR with a new one
if there are changes in the CSR's DNS SANs or IP SANs.
Then, subsequently, replace its corresponding certificate with a new one.
refresh_events: [DEPRECATED].
"""
super().__init__(charm, key)
# use StoredState to store the hash of the CSR
# to potentially trigger a CSR renewal on `refresh_events`
# to potentially trigger a CSR renewal
self._stored.set_default(
csr_hash=None,
)
Expand All @@ -320,8 +318,9 @@ def __init__(

# Use fqdn only if no SANs were given, and drop empty/duplicate SANs
sans = list(set(filter(None, (sans or [socket.getfqdn()]))))
self.sans_ip = list(filter(is_ip_address, sans))
self.sans_dns = list(filterfalse(is_ip_address, sans))
# sort SANS lists to avoid unnecessary csr renewals during reconciliation
self.sans_ip = sorted(filter(is_ip_address, sans))
self.sans_dns = sorted(filterfalse(is_ip_address, sans))

if self._check_juju_supports_secrets():
vault_backend = _SecretVaultBackend(charm, secret_label=VAULT_SECRET_LABEL)
Expand Down Expand Up @@ -367,13 +366,15 @@ def __init__(
)

if refresh_events:
for ev in refresh_events:
self.framework.observe(ev, self._on_refresh_event)
logger.warn(
"DEPRECATION WARNING. `refresh_events` is now deprecated. CertHandler will automatically refresh the CSR when necessary."
)

def _on_refresh_event(self, _):
"""Replace the latest current CSR with a new one if there are any SANs changes."""
if self._stored.csr_hash != self._csr_hash:
self._generate_csr(renew=True)
self._reconcile()

def _reconcile(self):
"""Run all logic that is independent of what event we're processing."""
self._refresh_csr_if_needed()

def _on_upgrade_charm(self, _):
has_privkey = self.vault.get_value("private-key")
Expand All @@ -388,6 +389,11 @@ def _on_upgrade_charm(self, _):
# this will call `self.private_key` which will generate a new privkey.
self._generate_csr(renew=True)

def _refresh_csr_if_needed(self):
"""Refresh the current CSR with a new one if there are any SANs changes."""
if self._stored.csr_hash is not None and self._stored.csr_hash != self._csr_hash:
self._generate_csr(renew=True)

def _migrate_vault(self):
peer_backend = _RelationVaultBackend(self.charm, relation_name="peers")

Expand Down Expand Up @@ -440,13 +446,17 @@ def enabled(self) -> bool:
return True

@property
def _csr_hash(self) -> int:
def _csr_hash(self) -> str:
"""A hash of the config that constructs the CSR.
Only include here the config options that, should they change, should trigger a renewal of
the CSR.
"""
return hash(

def _stable_hash(data):
return hashlib.sha256(str(data).encode()).hexdigest()

return _stable_hash(
(
tuple(self.sans_dns),
tuple(self.sans_ip),
Expand Down
76 changes: 46 additions & 30 deletions tests/scenario/test_cert_handler/test_cert_handler_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from cryptography import x509
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, fw):
if hostname := self._mock_san:
sans.append(hostname)

self.ch = CertHandler(self, key="ch", sans=sans, refresh_events=[self.on.config_changed])
self.ch = CertHandler(self, key="ch", sans=sans)

@property
def _mock_san(self):
Expand Down Expand Up @@ -145,6 +145,14 @@ def _cert_renew_patch():
yield patcher


@contextmanager
def _cert_generate_patch():
with patch(
"charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation"
) as patcher:
yield patcher


@pytest.mark.parametrize("leader", (True, False))
def test_cert_joins(ctx, certificates, leader):
with ctx.manager(
Expand Down Expand Up @@ -183,48 +191,56 @@ def test_cert_joins_peer_vault_backend(ctx_juju2, certificates, leader):
assert mgr.charm.ch.private_key


def test_renew_csr_on_sans_change(ctx, certificates):
# generate a CSR
# CertHandler generates a cert on `config_changed` event
@pytest.mark.parametrize(
"event,expected_generate_calls",
(("update_status", 0), ("start", 0), ("install", 0), ("config_changed", 1)),
)
def test_no_renew_if_no_initial_csr_was_generated(
event, expected_generate_calls, ctx, certificates
):
with _cert_renew_patch() as renew_patch:
with _cert_generate_patch() as generate_patch:
with ctx.manager(
event,
State(leader=True, relations=[certificates]),
) as mgr:

mgr.run()
assert renew_patch.call_count == 0
assert generate_patch.call_count == expected_generate_calls


@patch.object(CertHandler, "_stored", MagicMock())
@pytest.mark.parametrize(
"is_relation, event",
(
(False, "start"),
(True, "changed_event"),
(False, "config_changed"),
),
)
def test_csr_renew_on_any_event(is_relation, event, ctx, certificates):
with ctx.manager(
certificates.joined_event,
State(leader=True, relations=[certificates]),
getattr(certificates, event) if is_relation else event,
State(
leader=True,
relations=[certificates],
),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

# trigger a config_changed with a modified SAN
with _sans_patch():
with ctx.manager("config_changed", state_out) as mgr:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
# assert CSR contains updated SAN
assert get_sans_from_csr(csr) == {socket.getfqdn(), MOCK_HOSTNAME}


def test_csr_no_change_on_wrong_refresh_event(ctx, certificates):
with _cert_renew_patch() as renew_patch:
with ctx.manager(
"config_changed",
State(leader=True, relations=[certificates]),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

with _sans_patch():
with _cert_renew_patch() as renew_patch:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(csr) == {socket.getfqdn()}
assert renew_patch.call_count == 0


def test_csr_no_change(ctx, certificates):

with ctx.manager(
Expand Down

0 comments on commit 00f61db

Please sign in to comment.