diff --git a/src/k8s/pkg/k8sd/pki/control_plane.go b/src/k8s/pkg/k8sd/pki/control_plane.go index 13687cf1c..fae52a8b5 100644 --- a/src/k8s/pkg/k8sd/pki/control_plane.go +++ b/src/k8s/pkg/k8sd/pki/control_plane.go @@ -118,6 +118,11 @@ func (c *ControlPlanePKI) CompleteCertificates() error { } c.CACert = cert c.CAKey = key + } else { + certCheck := pkiutil.CertCheck{CN: "kubernetes-ca", AllowSelfSigned: true} + if err := certCheck.ValidateKeypair(c.CACert, c.CAKey); err != nil { + return fmt.Errorf("kubernetes CA certificate validation failure: %w", err) + } } // Generate self-signed client CA (if not set already) @@ -131,6 +136,11 @@ func (c *ControlPlanePKI) CompleteCertificates() error { } c.ClientCACert = cert c.ClientCAKey = key + } else { + certCheck := pkiutil.CertCheck{CN: "kubernetes-ca-client", AllowSelfSigned: true} + if err := certCheck.ValidateKeypair(c.ClientCACert, c.ClientCAKey); err != nil { + return fmt.Errorf("kubernetes client CA certificate validation failure: %w", err) + } } serverCACert, serverCAKey, err := pkiutil.LoadCertificate(c.CACert, c.CAKey) @@ -154,6 +164,11 @@ func (c *ControlPlanePKI) CompleteCertificates() error { } c.FrontProxyCACert = cert c.FrontProxyCAKey = key + } else { + certCheck := pkiutil.CertCheck{CN: "front-proxy-ca", AllowSelfSigned: true} + if err := certCheck.ValidateKeypair(c.FrontProxyCACert, c.FrontProxyCAKey); err != nil { + return fmt.Errorf("kubernetes front-proxy CA certificate validation failure: %w", err) + } } // Generate front proxy client certificate (ok to override) @@ -177,6 +192,11 @@ func (c *ControlPlanePKI) CompleteCertificates() error { c.FrontProxyClientCert = cert c.FrontProxyClientKey = key + } else { + certCheck := pkiutil.CertCheck{CN: "front-proxy-client", CaPEM: c.FrontProxyCACert} + if err := certCheck.ValidateKeypair(c.FrontProxyClientCert, c.FrontProxyClientKey); err != nil { + return fmt.Errorf("kubernetes front-proxy client certificate validation failure: %w", err) + } } // Generate service account key (if missing) @@ -213,6 +233,16 @@ func (c *ControlPlanePKI) CompleteCertificates() error { c.KubeletCert = cert c.KubeletKey = key + } else { + certCheck := pkiutil.CertCheck{ + CN: fmt.Sprintf("system:node:%s", c.hostname), + O: []string{"system:nodes"}, + CaPEM: c.CACert, + DNSSANs: []string{c.hostname}, + } + if err := certCheck.ValidateKeypair(c.KubeletCert, c.KubeletKey); err != nil { + return fmt.Errorf("kubelet certificate validation failure: %w", err) + } } // Generate apiserver-kubelet-client certificate (if missing) @@ -232,6 +262,15 @@ func (c *ControlPlanePKI) CompleteCertificates() error { c.APIServerKubeletClientCert = cert c.APIServerKubeletClientKey = key + } else { + certCheck := pkiutil.CertCheck{ + CN: fmt.Sprintf("apiserver-kubelet-client"), + O: []string{"system:masters"}, + CaPEM: c.ClientCACert, + } + if err := certCheck.ValidateKeypair(c.APIServerKubeletClientCert, c.APIServerKubeletClientKey); err != nil { + return fmt.Errorf("apiserver-kubelet-client certificate validation failure: %w", err) + } } // Generate kube-apiserver certificate (if missing) @@ -256,6 +295,15 @@ func (c *ControlPlanePKI) CompleteCertificates() error { c.APIServerCert = cert c.APIServerKey = key + } else { + certCheck := pkiutil.CertCheck{ + CN: fmt.Sprintf("kube-apiserver"), + CaPEM: c.CACert, + DNSSANs: []string{"kubernetes", "kubernetes.default", "kubernetes.default.svc", "kubernetes.default.svc.cluster", "kubernetes.default.svc.cluster.local"}, + } + if err := certCheck.ValidateKeypair(c.APIServerCert, c.APIServerKey); err != nil { + return fmt.Errorf("kube-apiservert certificate validation failure: %w", err) + } } for _, i := range []struct { diff --git a/src/k8s/pkg/k8sd/pki/control_plane_test.go b/src/k8s/pkg/k8sd/pki/control_plane_test.go index b9ae24f8b..6a635f67f 100644 --- a/src/k8s/pkg/k8sd/pki/control_plane_test.go +++ b/src/k8s/pkg/k8sd/pki/control_plane_test.go @@ -6,7 +6,9 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" + "crypto/x509/pkix" "encoding/pem" + "fmt" "net" "os" "testing" @@ -25,14 +27,60 @@ func mustReadTestData(t *testing.T, filename string) string { return string(data) } +// patchCertPEM can be used to modify certificates for testing purposes. +func patchCertPEM( + certPEM string, + caPEM string, + caKeyPEM string, + updateFunc func(*x509.Certificate) error, +) (string, string, error) { + block, _ := pem.Decode([]byte(certPEM)) + if block == nil { + return "", "", fmt.Errorf("failed to decode certificate") + } + + cert, _ := x509.ParseCertificate(block.Bytes) + if cert == nil { + return "", "", fmt.Errorf("failed to decode certificate") + } + + // Generate a new certificate based on the input certificate and the + // updates applied by "updateFunc". + template, err := pkiutil.GenerateCertificate( + cert.Subject, + cert.NotBefore, cert.NotAfter, false, + cert.DNSNames, cert.IPAddresses, + ) + if err != nil { + return "", "", fmt.Errorf("failed to generate patched certificate") + } + + if err = updateFunc(template); err != nil { + return "", "", fmt.Errorf("cert update failed: %w", err) + } + + caCert, caKey, err := pkiutil.LoadCertificate(caPEM, caKeyPEM) + if err != nil { + return "", "", fmt.Errorf("failed to load CA cert: %w", err) + } + + certPem, keyPem, err := pkiutil.SignCertificate(template, 2048, caCert, &caCert.PublicKey, caKey) + if err != nil { + return "", "", fmt.Errorf("failed to sign cert: %w", err) + } + + return certPem, keyPem, err +} + func TestControlPlaneCertificates(t *testing.T) { notBefore := time.Now() - c := pki.NewControlPlanePKI(pki.ControlPlanePKIOpts{ + opts := pki.ControlPlanePKIOpts{ Hostname: "h1", NotBefore: notBefore, NotAfter: notBefore.AddDate(1, 0, 0), AllowSelfSignedCA: true, - }) + } + c := pki.NewControlPlanePKI(opts) g := NewWithT(t) @@ -142,4 +190,101 @@ func TestControlPlaneCertificates(t *testing.T) { g.Expect(cert.DNSNames).To(ConsistOf(expectedDNSNames)) }) }) + + t.Run("InvalidSan", func(t *testing.T) { + c := pki.NewControlPlanePKI(opts) + g := NewWithT(t) + g.Expect(c.CompleteCertificates()).To(Succeed()) + + // Switch CA certificates, expecting certificate validation failures. + c.CACert = c.FrontProxyCACert + c.CAKey = c.FrontProxyCAKey + g.Expect(c.CompleteCertificates()).ToNot(Succeed()) + }) + + t.Run("KubeletCertExpired", func(t *testing.T) { + c := pki.NewControlPlanePKI(opts) + g := NewWithT(t) + g.Expect(c.CompleteCertificates()).To(Succeed()) + + var err error + c.KubeletCert, c.KubeletKey, err = patchCertPEM(c.KubeletCert, c.CACert, c.CAKey, func(cert *x509.Certificate) error { + cert.NotAfter = time.Now().AddDate(-1, 0, 0) + return nil + }) + g.Expect(err).ToNot(HaveOccurred()) + + err = c.CompleteCertificates() + g.Expect(err).To(MatchError(ContainSubstring("certificate expired"))) + }) + + t.Run("KubeletCertNotBefore", func(t *testing.T) { + c := pki.NewControlPlanePKI(opts) + g := NewWithT(t) + g.Expect(c.CompleteCertificates()).To(Succeed()) + + var err error + c.KubeletCert, c.KubeletKey, err = patchCertPEM(c.KubeletCert, c.CACert, c.CAKey, func(cert *x509.Certificate) error { + cert.NotBefore = time.Now().AddDate(1, 0, 0) + return nil + }) + g.Expect(err).ToNot(HaveOccurred()) + + err = c.CompleteCertificates() + g.Expect(err).To(MatchError(ContainSubstring("invalid certificate, not valid before"))) + }) + + t.Run("KubeletCertInvalidCN", func(t *testing.T) { + c := pki.NewControlPlanePKI(opts) + g := NewWithT(t) + g.Expect(c.CompleteCertificates()).To(Succeed()) + + var err error + c.KubeletCert, c.KubeletKey, err = patchCertPEM(c.KubeletCert, c.CACert, c.CAKey, func(cert *x509.Certificate) error { + cert.Subject = pkix.Name{ + CommonName: "unexpected-cn", + Organization: cert.Subject.Organization, + } + return nil + }) + g.Expect(err).ToNot(HaveOccurred()) + + err = c.CompleteCertificates() + g.Expect(err).To(MatchError(ContainSubstring("invalid certificate CN"))) + }) + + t.Run("KubeletCertInvalidOrganization", func(t *testing.T) { + c := pki.NewControlPlanePKI(opts) + g := NewWithT(t) + g.Expect(c.CompleteCertificates()).To(Succeed()) + + var err error + c.KubeletCert, c.KubeletKey, err = patchCertPEM(c.KubeletCert, c.CACert, c.CAKey, func(cert *x509.Certificate) error { + cert.Subject = pkix.Name{ + CommonName: cert.Subject.CommonName, + Organization: []string{"unexpected-organization"}, + } + return nil + }) + g.Expect(err).ToNot(HaveOccurred()) + + err = c.CompleteCertificates() + g.Expect(err).To(MatchError(ContainSubstring("missing cert organization"))) + }) + + t.Run("KubeletCertInvalidDNSName", func(t *testing.T) { + c := pki.NewControlPlanePKI(opts) + g := NewWithT(t) + g.Expect(c.CompleteCertificates()).To(Succeed()) + + var err error + c.KubeletCert, c.KubeletKey, err = patchCertPEM(c.KubeletCert, c.CACert, c.CAKey, func(cert *x509.Certificate) error { + cert.DNSNames = []string{"some-other-dnsname"} + return nil + }) + g.Expect(err).ToNot(HaveOccurred()) + + err = c.CompleteCertificates() + g.Expect(err).To(MatchError(MatchRegexp(`certificate dns name \(.*\) validation failure`))) + }) } diff --git a/src/k8s/pkg/utils/pki/validate.go b/src/k8s/pkg/utils/pki/validate.go new file mode 100755 index 000000000..0fb77a3c1 --- /dev/null +++ b/src/k8s/pkg/utils/pki/validate.go @@ -0,0 +1,86 @@ +package pkiutil + +import ( + "crypto/x509" + "fmt" + "slices" + "time" +) + +// CertCheck can be used to validate certificates. Unspecified fields are +// ignored. "NotBefore" and "NotAfter" are checked implicitly. +type CertCheck struct { + // Ensure that the certificate has the specified Common Name. + CN string + // Ensure that the certificate contains the following organizations. + O []string + // Ensure that the certificate contains the following DNS SANs. + DNSSANs []string + // Ensure that the certificate contains the following IP SANs. + IPSANs []string + // Validate the certificate against the specified CA certificate. + CaPEM string + AllowSelfSigned bool +} + +func (check CertCheck) ValidateKeypair(certPEM string, keyPEM string) error { + cert, _, err := LoadCertificate(certPEM, keyPEM) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + + return check.ValidateCert(cert) +} + +func (check CertCheck) ValidateCert(cert *x509.Certificate) error { + if check.CN != "" && check.CN != cert.Subject.CommonName { + return fmt.Errorf("invalid certificate CN, expected: %s, actual: %s ", + check.CN, cert.Subject.CommonName) + } + for _, checkO := range check.O { + if !slices.Contains(cert.Subject.Organization, checkO) { + return fmt.Errorf("missing cert organization: %s, actual: %v", + checkO, cert.Subject.Organization) + } + } + + now := time.Now() + if now.Before(cert.NotBefore) { + return fmt.Errorf("invalid certificate, not valid before: %v, current time: %v", + cert.NotBefore, now) + } + if now.After(cert.NotAfter) { + return fmt.Errorf("certificate expired since: %v, current time: %v", + cert.NotAfter, now) + } + + if !check.AllowSelfSigned { + verifyOpts := x509.VerifyOptions{} + if check.CaPEM != "" { + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM([]byte(check.CaPEM)) { + return fmt.Errorf("invalid CA certificate") + } + verifyOpts.Roots = roots + } + + if _, err := cert.Verify(verifyOpts); err != nil { + return fmt.Errorf("certificate validation failure: %w", err) + } + } + + for _, dnsName := range check.DNSSANs { + if err := cert.VerifyHostname(dnsName); err != nil { + return fmt.Errorf("certificate dns name (%s) validation failure: %w, allowed dns names: %v", + dnsName, err, cert.DNSNames) + } + } + for _, ip := range check.IPSANs { + if err := cert.VerifyHostname("[" + ip + "]"); err != nil { + return fmt.Errorf("certificate ip (%s) validation failure: %w, allowed IPs: %v", + ip, err, cert.IPAddresses) + } + } + + return nil +}