Skip to content

Commit

Permalink
refresh csr if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldmitry committed Oct 10, 2024
1 parent b12d143 commit f47ac3c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 42 deletions.
24 changes: 12 additions & 12 deletions lib/charms/observability_libs/v1/cert_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a"
LIBAPI = 1
LIBPATCH = 13
LIBPATCH = 14

VAULT_SECRET_LABEL = "cert-handler-private-vault"

Expand Down Expand Up @@ -301,14 +301,11 @@ def __init__(
Must match metadata.yaml.
cert_subject: Custom subject. Name collisions are under the caller's responsibility.
sans: DNS names. If none are given, use FQDN.
refresh_events: an optional list of bound events which
will be observed to replace the current CSR with a new one
if there are changes in the CSR's DNS SANs or IP SANs.
Then, subsequently, replace its corresponding certificate with a new one.
refresh_events: [DEPRECATED].
"""
super().__init__(charm, key)
# use StoredState to store the hash of the CSR
# to potentially trigger a CSR renewal on `refresh_events`
# to potentially trigger a CSR renewal
self._stored.set_default(
csr_hash=None,
)
Expand Down Expand Up @@ -367,13 +364,11 @@ def __init__(
)

if refresh_events:
for ev in refresh_events:
self.framework.observe(ev, self._on_refresh_event)
logger.warn(
"DEPRECATION WARNING. `refresh_events` is now deprecated. CertHandler will automatically refresh the CSR when necessary."
)

def _on_refresh_event(self, _):
"""Replace the latest current CSR with a new one if there are any SANs changes."""
if self._stored.csr_hash != self._csr_hash:
self._generate_csr(renew=True)
self._refresh_csr_if_needed()

def _on_upgrade_charm(self, _):
has_privkey = self.vault.get_value("private-key")
Expand All @@ -388,6 +383,11 @@ def _on_upgrade_charm(self, _):
# this will call `self.private_key` which will generate a new privkey.
self._generate_csr(renew=True)

def _refresh_csr_if_needed(self):
"""Refresh the latest current CSR with a new one if there are any SANs changes."""
if self._stored.csr_hash is not None and self._stored.csr_hash != self._csr_hash:
self._generate_csr(renew=True)

def _migrate_vault(self):
peer_backend = _RelationVaultBackend(self.charm, relation_name="peers")

Expand Down
73 changes: 43 additions & 30 deletions tests/scenario/test_cert_handler/test_cert_handler_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from cryptography import x509
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, fw):
if hostname := self._mock_san:
sans.append(hostname)

self.ch = CertHandler(self, key="ch", sans=sans, refresh_events=[self.on.config_changed])
self.ch = CertHandler(self, key="ch", sans=sans)

@property
def _mock_san(self):
Expand Down Expand Up @@ -145,6 +145,14 @@ def _cert_renew_patch():
yield patcher


@contextmanager
def _cert_generate_patch():
with patch(
"charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_creation"
) as patcher:
yield patcher


@pytest.mark.parametrize("leader", (True, False))
def test_cert_joins(ctx, certificates, leader):
with ctx.manager(
Expand Down Expand Up @@ -183,48 +191,53 @@ def test_cert_joins_peer_vault_backend(ctx_juju2, certificates, leader):
assert mgr.charm.ch.private_key


def test_renew_csr_on_sans_change(ctx, certificates):
# generate a CSR
@pytest.mark.parametrize(
"event,generate_call_count",
(("update_status", 0), ("start", 0), ("install", 0), ("config_changed", 1)),
)
def test_no_renew_if_no_initial_csr_was_generated(event, generate_call_count, ctx, certificates):
with _cert_renew_patch() as renew_patch:
with _cert_generate_patch() as generate_patch:
with ctx.manager(
event,
State(leader=True, relations=[certificates]),
) as mgr:

mgr.run()
assert renew_patch.call_count == 0
assert generate_patch.call_count == generate_call_count


@patch.object(CertHandler, "_stored", MagicMock())
@pytest.mark.parametrize(
"is_relation, event",
(
(False, "start"),
(True, "changed_event"),
(False, "config_changed"),
),
)
def test_csr_renew_on_any_event(is_relation, event, ctx, certificates):
with ctx.manager(
certificates.joined_event,
State(leader=True, relations=[certificates]),
getattr(certificates, event) if is_relation else event,
State(
leader=True,
relations=[certificates],
),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

# trigger a config_changed with a modified SAN
with _sans_patch():
with ctx.manager("config_changed", state_out) as mgr:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
# assert CSR contains updated SAN
assert get_sans_from_csr(csr) == {socket.getfqdn(), MOCK_HOSTNAME}


def test_csr_no_change_on_wrong_refresh_event(ctx, certificates):
with _cert_renew_patch() as renew_patch:
with ctx.manager(
"config_changed",
State(leader=True, relations=[certificates]),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

with _sans_patch():
with _cert_renew_patch() as renew_patch:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(csr) == {socket.getfqdn()}
assert renew_patch.call_count == 0


def test_csr_no_change(ctx, certificates):

with ctx.manager(
Expand Down

0 comments on commit f47ac3c

Please sign in to comment.