Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh CSR "whenever" needed #110

Merged
merged 6 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions lib/charms/observability_libs/v1/cert_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Since this library uses [Juju Secrets](https://juju.is/docs/juju/secret) it requires Juju >= 3.0.3.
"""
import abc
import hashlib
import ipaddress
import json
import socket
Expand Down Expand Up @@ -67,7 +68,7 @@

LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a"
LIBAPI = 1
LIBPATCH = 13
LIBPATCH = 14

VAULT_SECRET_LABEL = "cert-handler-private-vault"

Expand Down Expand Up @@ -301,14 +302,11 @@ def __init__(
Must match metadata.yaml.
cert_subject: Custom subject. Name collisions are under the caller's responsibility.
sans: DNS names. If none are given, use FQDN.
refresh_events: an optional list of bound events which
will be observed to replace the current CSR with a new one
if there are changes in the CSR's DNS SANs or IP SANs.
Then, subsequently, replace its corresponding certificate with a new one.
refresh_events: [DEPRECATED].
"""
super().__init__(charm, key)
# use StoredState to store the hash of the CSR
# to potentially trigger a CSR renewal on `refresh_events`
# to potentially trigger a CSR renewal
self._stored.set_default(
csr_hash=None,
)
Expand All @@ -320,8 +318,9 @@ def __init__(

# Use fqdn only if no SANs were given, and drop empty/duplicate SANs
sans = list(set(filter(None, (sans or [socket.getfqdn()]))))
self.sans_ip = list(filter(is_ip_address, sans))
self.sans_dns = list(filterfalse(is_ip_address, sans))
# sort SANS lists to avoid unnecessary csr renewals during reconciliation
self.sans_ip = sorted(filter(is_ip_address, sans))
michaeldmitry marked this conversation as resolved.
Show resolved Hide resolved
self.sans_dns = sorted(filterfalse(is_ip_address, sans))

if self._check_juju_supports_secrets():
vault_backend = _SecretVaultBackend(charm, secret_label=VAULT_SECRET_LABEL)
Expand Down Expand Up @@ -367,13 +366,15 @@ def __init__(
)

if refresh_events:
for ev in refresh_events:
self.framework.observe(ev, self._on_refresh_event)
logger.warn(
"DEPRECATION WARNING. `refresh_events` is now deprecated. CertHandler will automatically refresh the CSR when necessary."
)

def _on_refresh_event(self, _):
"""Replace the latest current CSR with a new one if there are any SANs changes."""
if self._stored.csr_hash != self._csr_hash:
self._generate_csr(renew=True)
self._reconcile()

def _reconcile(self):
"""Run all logic that is independent of what event we're processing."""
self._refresh_csr_if_needed()
michaeldmitry marked this conversation as resolved.
Show resolved Hide resolved

def _on_upgrade_charm(self, _):
has_privkey = self.vault.get_value("private-key")
Expand All @@ -388,6 +389,11 @@ def _on_upgrade_charm(self, _):
# this will call `self.private_key` which will generate a new privkey.
self._generate_csr(renew=True)

def _refresh_csr_if_needed(self):
"""Refresh the current CSR with a new one if there are any SANs changes."""
if self._stored.csr_hash is not None and self._stored.csr_hash != self._csr_hash:
self._generate_csr(renew=True)

def _migrate_vault(self):
peer_backend = _RelationVaultBackend(self.charm, relation_name="peers")

Expand Down Expand Up @@ -440,13 +446,17 @@ def enabled(self) -> bool:
return True

@property
def _csr_hash(self) -> int:
def _csr_hash(self) -> str:
"""A hash of the config that constructs the CSR.

Only include here the config options that, should they change, should trigger a renewal of
the CSR.
"""
return hash(

def _stable_hash(data):
return hashlib.sha256(str(data).encode()).hexdigest()

return _stable_hash(
(
tuple(self.sans_dns),
tuple(self.sans_ip),
Expand Down
76 changes: 46 additions & 30 deletions tests/scenario/test_cert_handler/test_cert_handler_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from cryptography import x509
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, fw):
if hostname := self._mock_san:
sans.append(hostname)

self.ch = CertHandler(self, key="ch", sans=sans, refresh_events=[self.on.config_changed])
self.ch = CertHandler(self, key="ch", sans=sans)

@property
def _mock_san(self):
Expand Down Expand Up @@ -145,6 +145,14 @@ def _cert_renew_patch():
yield patcher


@contextmanager
def _cert_generate_patch():
with patch(
"charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation"
) as patcher:
yield patcher


@pytest.mark.parametrize("leader", (True, False))
def test_cert_joins(ctx, certificates, leader):
with ctx.manager(
Expand Down Expand Up @@ -183,48 +191,56 @@ def test_cert_joins_peer_vault_backend(ctx_juju2, certificates, leader):
assert mgr.charm.ch.private_key


def test_renew_csr_on_sans_change(ctx, certificates):
# generate a CSR
# CertHandler generates a cert on `config_changed` event
@pytest.mark.parametrize(
"event,expected_generate_calls",
(("update_status", 0), ("start", 0), ("install", 0), ("config_changed", 1)),
michaeldmitry marked this conversation as resolved.
Show resolved Hide resolved
)
def test_no_renew_if_no_initial_csr_was_generated(
event, expected_generate_calls, ctx, certificates
):
with _cert_renew_patch() as renew_patch:
with _cert_generate_patch() as generate_patch:
with ctx.manager(
event,
State(leader=True, relations=[certificates]),
) as mgr:

mgr.run()
assert renew_patch.call_count == 0
assert generate_patch.call_count == expected_generate_calls


@patch.object(CertHandler, "_stored", MagicMock())
@pytest.mark.parametrize(
"is_relation, event",
(
(False, "start"),
(True, "changed_event"),
(False, "config_changed"),
),
)
def test_csr_renew_on_any_event(is_relation, event, ctx, certificates):
with ctx.manager(
certificates.joined_event,
State(leader=True, relations=[certificates]),
getattr(certificates, event) if is_relation else event,
State(
leader=True,
relations=[certificates],
),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

# trigger a config_changed with a modified SAN
with _sans_patch():
with ctx.manager("config_changed", state_out) as mgr:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
# assert CSR contains updated SAN
assert get_sans_from_csr(csr) == {socket.getfqdn(), MOCK_HOSTNAME}


def test_csr_no_change_on_wrong_refresh_event(ctx, certificates):
with _cert_renew_patch() as renew_patch:
with ctx.manager(
"config_changed",
State(leader=True, relations=[certificates]),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

with _sans_patch():
with _cert_renew_patch() as renew_patch:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(csr) == {socket.getfqdn()}
assert renew_patch.call_count == 0


def test_csr_no_change(ctx, certificates):

with ctx.manager(
Expand Down
Loading