diff --git a/pki/interface.go b/pki/interface.go index 95fd2c026..8b3631309 100644 --- a/pki/interface.go +++ b/pki/interface.go @@ -63,9 +63,13 @@ type Validator interface { // ErrCRLMissing and ErrCRLExpired signal that at least one of the certificates cannot be validated reliably. // If the certificate was revoked on an expired CRL, it wil return ErrCertRevoked. // Ignoring all errors except ErrCertRevoked changes the behavior from hard-fail to soft-fail. Without a truststore, the Validator is a noop if set to soft-fail + // Validate uses the configured soft-/hard-fail strategy // The certificate chain is expected to be sorted leaf to root. Validate(chain []*x509.Certificate) error + // ValidateStrict does the same as Validate, except it always uses the hard-fail strategy. + ValidateStrict(chain []*x509.Certificate) error + // SetVerifyPeerCertificateFunc sets config.ValidatePeerCertificate to use Validate. SetVerifyPeerCertificateFunc(config *tls.Config) error diff --git a/pki/mock.go b/pki/mock.go index d8f2321aa..21d8c97a3 100644 --- a/pki/mock.go +++ b/pki/mock.go @@ -189,6 +189,20 @@ func (mr *MockValidatorMockRecorder) Validate(chain any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockValidator)(nil).Validate), chain) } +// ValidateStrict mocks base method. +func (m *MockValidator) ValidateStrict(chain []*x509.Certificate) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateStrict", chain) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateStrict indicates an expected call of ValidateStrict. +func (mr *MockValidatorMockRecorder) ValidateStrict(chain any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateStrict", reflect.TypeOf((*MockValidator)(nil).ValidateStrict), chain) +} + // MockProvider is a mock of Provider interface. type MockProvider struct { ctrl *gomock.Controller @@ -281,3 +295,17 @@ func (mr *MockProviderMockRecorder) Validate(chain any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockProvider)(nil).Validate), chain) } + +// ValidateStrict mocks base method. +func (m *MockProvider) ValidateStrict(chain []*x509.Certificate) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateStrict", chain) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateStrict indicates an expected call of ValidateStrict. +func (mr *MockProviderMockRecorder) ValidateStrict(chain any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateStrict", reflect.TypeOf((*MockProvider)(nil).ValidateStrict), chain) +} diff --git a/pki/validator.go b/pki/validator.go index e8e70dbd0..ad1a9da3f 100644 --- a/pki/validator.go +++ b/pki/validator.go @@ -128,6 +128,14 @@ func (v *validator) syncLoop(ctx context.Context) { } func (v *validator) Validate(chain []*x509.Certificate) error { + return v.validate(chain, v.softfail) +} + +func (v *validator) ValidateStrict(chain []*x509.Certificate) error { + return v.validate(chain, false) +} + +func (v *validator) validate(chain []*x509.Certificate, softfail bool) error { var cert *x509.Certificate var err error for i := range chain { @@ -135,7 +143,7 @@ func (v *validator) Validate(chain []*x509.Certificate) error { // check in reverse order to prevent CRL expiration errors due to revoked CAs no longer issuing CRLs if err = v.validateCert(cert); err != nil { errOut := fmt.Errorf("%w: subject=%s, S/N=%s, issuer=%s", err, cert.Subject.String(), cert.SerialNumber.String(), cert.Issuer.String()) - if v.softfail && !(errors.Is(err, ErrCertRevoked) || errors.Is(err, ErrCertBanned)) { + if softfail && !(errors.Is(err, ErrCertRevoked) || errors.Is(err, ErrCertBanned)) { // Accept the certificate even if it cannot be properly validated logger().WithError(errOut).Error("Certificate CRL check softfail bypass. Might be unsafe, find cause of failure!") continue diff --git a/pki/validator_test.go b/pki/validator_test.go index 081275afd..2e8e1f6bf 100644 --- a/pki/validator_test.go +++ b/pki/validator_test.go @@ -110,11 +110,21 @@ func TestValidator_Validate(t *testing.T) { assert.ErrorIs(t, err, expected) } } + fnStrict := func(expected error) { + val.softfail = true // make sure it ignores the configured value + err = val.ValidateStrict([]*x509.Certificate{cert}) + if expected == nil { + assert.NoError(t, err) + } else { + assert.ErrorIs(t, err, expected) + } + } t.Run("softfail", func(t *testing.T) { fn(true, softfailReturn) }) t.Run("hardfail", func(t *testing.T) { fn(false, hardfailReturn) + fnStrict(hardfailReturn) }) }