Skip to content

Commit

Permalink
Fix certificate refresh and add e2e tests (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
petrutlucian94 authored Nov 6, 2024
1 parent 687df33 commit 13875ac
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 17 deletions.
54 changes: 47 additions & 7 deletions src/k8s/pkg/k8sd/api/certificates_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package api

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509/pkix"
"encoding/base64"
"fmt"
"math"
"math/rand"
"math/big"
"net"
"net/http"
"path/filepath"
Expand All @@ -29,7 +33,11 @@ import (
)

func (e *Endpoints) postRefreshCertsPlan(s state.State, r *http.Request) response.Response {
seed := rand.Intn(math.MaxInt)
seedBigInt, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt))
if err != nil {
return response.InternalError(fmt.Errorf("failed to generate seed: %w", err))
}
seed := int(seedBigInt.Int64())

snap := e.provider.Snap()
isWorker, err := snaputil.IsWorker(snap)
Expand Down Expand Up @@ -216,6 +224,18 @@ func refreshCertsRunWorker(s state.State, r *http.Request, snap snap.Snap) respo
certificates.CACert = clusterConfig.Certificates.GetCACert()
certificates.ClientCACert = clusterConfig.Certificates.GetClientCACert()

k8sdPublicKey, err := pkiutil.LoadRSAPublicKey(clusterConfig.Certificates.GetK8sdPublicKey())
if err != nil {
return response.InternalError(fmt.Errorf("failed to load k8sd public key, error: %w", err))
}

hostnames := []string{snap.Hostname()}
ips := []net.IP{net.ParseIP(s.Address().Hostname())}

extraIPs, extraNames := utils.SplitIPAndDNSSANs(req.ExtraSANs)
hostnames = append(hostnames, extraNames...)
ips = append(ips, extraIPs...)

g, ctx := errgroup.WithContext(r.Context())

for _, csr := range []struct {
Expand All @@ -234,8 +254,8 @@ func refreshCertsRunWorker(s state.State, r *http.Request, snap snap.Snap) respo
commonName: fmt.Sprintf("system:node:%s", snap.Hostname()),
organization: []string{"system:nodes"},
usages: []certv1.KeyUsage{certv1.UsageDigitalSignature, certv1.UsageKeyEncipherment, certv1.UsageServerAuth},
hostnames: []string{snap.Hostname()},
ips: []net.IP{net.ParseIP(s.Address().Hostname())},
hostnames: hostnames,
ips: ips,
signerName: "k8sd.io/kubelet-serving",
certificate: &certificates.KubeletCert,
key: &certificates.KubeletKey,
Expand Down Expand Up @@ -272,14 +292,34 @@ func refreshCertsRunWorker(s state.State, r *http.Request, snap snap.Snap) respo
return fmt.Errorf("failed to generate CSR for %s: %w", csr.name, err)
}

// Obtain the SHA256 sum of the CSR request.
hash := sha256.New()
_, err = hash.Write([]byte(csrPEM))
if err != nil {
return fmt.Errorf("failed to checksum CSR %s, err: %w", csr.name, err)
}

signature, err := rsa.EncryptPKCS1v15(rand.Reader, k8sdPublicKey, hash.Sum(nil))
if err != nil {
return fmt.Errorf("failed to sign CSR %s, err: %w", csr.name, err)
}
signatureB64 := base64.StdEncoding.EncodeToString(signature)

expirationSeconds := int32(req.ExpirationSeconds)

if _, err = client.CertificatesV1().CertificateSigningRequests().Create(ctx, &certv1.CertificateSigningRequest{
ObjectMeta: metav1.ObjectMeta{
Name: csr.name,
Annotations: map[string]string{
"k8sd.io/signature": signatureB64,
"k8sd.io/node": snap.Hostname(),
},
},
Spec: certv1.CertificateSigningRequestSpec{
Request: []byte(csrPEM),
Usages: csr.usages,
SignerName: csr.signerName,
Request: []byte(csrPEM),
ExpirationSeconds: &expirationSeconds,
Usages: csr.usages,
SignerName: csr.signerName,
},
}, metav1.CreateOptions{}); err != nil {
return fmt.Errorf("failed to create CSR for %s: %w", csr.name, err)
Expand Down
22 changes: 16 additions & 6 deletions src/k8s/pkg/k8sd/controllers/csrsigning/reconcile.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"time"

"github.com/canonical/k8s/pkg/utils"
pkiutil "github.com/canonical/k8s/pkg/utils/pki"
certv1 "k8s.io/api/certificates/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -96,6 +97,15 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request)
return ctrl.Result{}, err
}

notBefore := time.Now()
var notAfter time.Time

if obj.Spec.ExpirationSeconds != nil {
notAfter = utils.SecondsToExpirationDate(notBefore, int(*obj.Spec.ExpirationSeconds))
} else {
notAfter = time.Now().AddDate(10, 0, 0)
}

var crtPEM []byte
switch obj.Spec.SignerName {
case "k8sd.io/kubelet-serving":
Expand All @@ -114,8 +124,8 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request)
CommonName: obj.Spec.Username,
Organization: obj.Spec.Groups,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0), // TODO: expiration date from obj, or config
NotBefore: notBefore,
NotAfter: notAfter,
IPAddresses: certRequest.IPAddresses,
DNSNames: certRequest.DNSNames,
BasicConstraintsValid: true,
Expand Down Expand Up @@ -149,8 +159,8 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request)
CommonName: obj.Spec.Username,
Organization: obj.Spec.Groups,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0), // TODO: expiration date from obj, or config
NotBefore: notBefore,
NotAfter: notAfter,
BasicConstraintsValid: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
Expand Down Expand Up @@ -181,8 +191,8 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request)
Subject: pkix.Name{
CommonName: "system:kube-proxy",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0), // TODO: expiration date from obj, or config
NotBefore: notBefore,
NotAfter: notAfter,
BasicConstraintsValid: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
Expand Down
8 changes: 7 additions & 1 deletion src/k8s/pkg/k8sd/controllers/csrsigning/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"fmt"

"github.com/canonical/k8s/pkg/utils"
Expand All @@ -21,7 +22,12 @@ func validateCSR(obj *certv1.CertificateSigningRequest, priv *rsa.PrivateKey) er
return fmt.Errorf("failed to parse x509 certificate request: %w", err)
}

encryptedSignature := obj.Annotations["k8sd.io/signature"]
encryptedSignatureB64 := obj.Annotations["k8sd.io/signature"]
encryptedSignature, err := base64.StdEncoding.DecodeString(encryptedSignatureB64)
if err != nil {
return fmt.Errorf("failed to decode b64 signature: %w", err)
}

signature, err := rsa.DecryptPKCS1v15(nil, priv, []byte(encryptedSignature))
if err != nil {
return fmt.Errorf("failed to decrypt signature: %w", err)
Expand Down
5 changes: 3 additions & 2 deletions src/k8s/pkg/k8sd/controllers/csrsigning/validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509/pkix"
"encoding/base64"
"testing"

pkiutil "github.com/canonical/k8s/pkg/utils/pki"
Expand Down Expand Up @@ -93,7 +94,7 @@ func TestValidateCSREncryption(t *testing.T) {
},
},
expectErr: true,
expectErrMessage: "failed to decrypt signature",
expectErrMessage: "failed to decode b64 signature",
},
{
name: "Missing Signature",
Expand Down Expand Up @@ -219,5 +220,5 @@ func mustCreateEncryptedSignature(g Gomega, pub *rsa.PublicKey, csrPEM string) s
signature, err := rsa.EncryptPKCS1v15(rand.Reader, pub, hash.Sum(nil))
g.Expect(err).NotTo(HaveOccurred())

return string(signature)
return base64.StdEncoding.EncodeToString(signature)
}
1 change: 1 addition & 0 deletions tests/integration/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest==7.3.1
PyYAML==6.0.1
tenacity==8.2.3
pylint==3.2.5
cryptography==43.0.3
9 changes: 9 additions & 0 deletions tests/integration/templates/bootstrap-csr-auto-approve.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
cluster-config:
network:
enabled: true
dns:
enabled: true
metrics-server:
enabled: true
annotations:
k8sd/v1alpha1/csrsigning/auto-approve: true
88 changes: 88 additions & 0 deletions tests/integration/tests/test_clustering.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
#
# Copyright 2024 Canonical, Ltd.
#
import datetime
import logging
import os
import subprocess
import tempfile
from typing import List

import pytest
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from test_util import config, harness, util

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -228,3 +234,85 @@ def test_join_with_custom_token_name(instances: List[harness.Instance]):
cluster_node.exec(["k8s", "remove-node", joining_cp_with_hostname.id])
nodes = util.ready_nodes(cluster_node)
assert len(nodes) == 1, "cp node with hostname should be removed from the cluster"


@pytest.mark.node_count(2)
@pytest.mark.bootstrap_config(
    (config.MANIFESTS_DIR / "bootstrap-csr-auto-approve.yaml").read_text()
)
def test_cert_refresh(instances: List[harness.Instance]):
    """Refresh the certificates of a worker and a control plane node.

    The cluster is bootstrapped from the ``bootstrap-csr-auto-approve.yaml``
    template. A worker node joins the cluster, then ``k8s refresh-certs`` is
    run on the worker and on the control plane node with a one year expiry and
    an extra SAN. The refreshed certificates are pulled from the nodes and
    checked for the expected lifetime and SAN entry, and finally the cluster is
    expected to become ready again.
    """
    cluster_node = instances[0]
    joining_worker = instances[1]

    join_token_worker = util.get_join_token(cluster_node, joining_worker, "--worker")
    util.join_cluster(joining_worker, join_token_worker)

    util.wait_until_k8s_ready(cluster_node, instances)
    nodes = util.ready_nodes(cluster_node)
    assert len(nodes) == 2, "nodes should have joined cluster"

    assert "control-plane" in util.get_local_node_status(cluster_node)
    assert "worker" in util.get_local_node_status(joining_worker)

    extra_san = "test_san.local"

    def _check_cert(instance, cert_fname):
        # Ensure that the certificate was refreshed, having the right expiry date
        # and extra SAN.
        cert_dir = _get_k8s_cert_dir(instance)
        cert_path = os.path.join(cert_dir, cert_fname)

        cert = _get_instance_cert(instance, cert_path)

        # Certificate validity timestamps are expressed in UTC. Compare two
        # timezone-aware UTC values so that the result does not depend on the
        # timezone of the machine running the tests. This also avoids the
        # deprecated naive `not_valid_after` attribute.
        now = datetime.datetime.now(datetime.timezone.utc)
        assert (cert.not_valid_after_utc - now).days in (364, 365)

        san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
        san_dns_names = san.value.get_values_for_type(x509.DNSName)
        assert extra_san in san_dns_names

    joining_worker.exec(
        ["k8s", "refresh-certs", "--expires-in", "1y", "--extra-sans", extra_san]
    )

    _check_cert(joining_worker, "kubelet.crt")

    cluster_node.exec(
        ["k8s", "refresh-certs", "--expires-in", "1y", "--extra-sans", extra_san]
    )

    _check_cert(cluster_node, "kubelet.crt")
    _check_cert(cluster_node, "apiserver.crt")

    # Ensure that the services come back online after refreshing the certificates.
    util.wait_until_k8s_ready(cluster_node, instances)


def _get_k8s_cert_dir(instance: harness.Instance):
    """Return the first known Kubernetes PKI directory present on the instance.

    The candidate directories are probed in order and probing stops at the
    first one that exists. Raises Exception if none of them is found.
    """
    candidates = (
        "/etc/kubernetes/pki/",
        "/var/snap/k8s/common/etc/kubernetes/pki/",
    )
    # The generator is lazy, so later candidates are only probed when the
    # earlier ones are missing.
    found = next(
        (
            candidate
            for candidate in candidates
            if _instance_path_exists(instance, candidate)
        ),
        None,
    )
    if found is None:
        raise Exception("Could not find k8s certificates dir.")
    return found


def _instance_path_exists(instance: harness.Instance, remote_path: str):
    """Tell whether ``remote_path`` can be listed on the given instance.

    Runs ``ls`` on the instance; a ``subprocess.CalledProcessError`` raised by
    the command is reported as ``False``, a successful run as ``True``.
    """
    try:
        instance.exec(["ls", remote_path])
    except subprocess.CalledProcessError:
        return False
    else:
        return True


def _get_instance_cert(
    instance: harness.Instance, remote_path: str
) -> x509.Certificate:
    """Pull a PEM certificate from the instance and return it parsed.

    The remote file is copied into a local temporary file, whose contents are
    then loaded as an X.509 certificate.
    """
    with tempfile.NamedTemporaryFile() as local_copy:
        instance.pull_file(remote_path, local_copy.name)

        pem_data = local_copy.read()
        return x509.load_pem_x509_certificate(pem_data, default_backend())
2 changes: 1 addition & 1 deletion tests/integration/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ passenv =
[flake8]
max-line-length = 120
select = E,W,F,C,N
ignore = W503
ignore = W503,E231,E226
exclude = venv,.git,.tox,.tox_env,.venv,build,dist,*.egg_info
show-source = true

0 comments on commit 13875ac

Please sign in to comment.