diff --git a/vdr/didx509/resolver_test.go b/vdr/didx509/resolver_test.go index 0fb808aad..b83ddc228 100644 --- a/vdr/didx509/resolver_test.go +++ b/vdr/didx509/resolver_test.go @@ -42,7 +42,8 @@ func TestManager_Resolve_OtherName(t *testing.T) { metadata := resolver2.ResolveMetadata{} otherNameValue := "A_BIG_STRING" - _, certChain, rootCertificate, _, signingCert, err := BuildCertChain(otherNameValue) + otherNameValueSecondary := "A_SECOND_STRING" + _, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{otherNameValue, otherNameValueSecondary}) require.NoError(t, err) metadata.JwtProtectedHeaders = make(map[string]interface{}) metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain @@ -82,7 +83,7 @@ func TestManager_Resolve_OtherName(t *testing.T) { didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) }) - t.Run("happy flow, policy depth of 1", func(t *testing.T) { + t.Run("happy flow, policy depth of 1 and primary value", func(t *testing.T) { validator.EXPECT().ValidateStrict(gomock.Any()) resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) @@ -94,7 +95,21 @@ func TestManager_Resolve_OtherName(t *testing.T) { didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) }) - t.Run("happy flow, policy depth of 2", func(t *testing.T) { + t.Run("happy flow, policy depth of 1 and secondary value", func(t *testing.T) { + rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValueSecondary)) + + validator.EXPECT().ValidateStrict(gomock.Any()) + resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) + + require.NoError(t, err) + assert.NotNil(t, resolve) + require.NoError(t, err) + assert.NotNil(t, documentMetadata) + // Check that the DID url is did#0 + didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") + assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) + }) + t.Run("happy flow, policy depth of 2 of type OU", func(t *testing.T) { rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::subject:OU:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, "The%20A-Team")) validator.EXPECT().ValidateStrict(gomock.Any()) @@ -108,6 +123,34 @@ func TestManager_Resolve_OtherName(t *testing.T) { didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) }) + t.Run("happy flow, policy depth of 2, primary and secondary", func(t *testing.T) { + rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, otherNameValueSecondary)) + + validator.EXPECT().ValidateStrict(gomock.Any()) + resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) + + require.NoError(t, err) + assert.NotNil(t, resolve) + require.NoError(t, err) + assert.NotNil(t, documentMetadata) + // Check that the DID url is did#0 + didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") + assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) + }) + t.Run("happy flow, policy depth of 2, secondary and primary", func(t *testing.T) { + rootDID := did.MustParseDID(fmt.Sprintf("did:x509:0:%s:%s::san:otherName:%s::san:otherName:%s", "sha256", sha256Sum(rootCertificate.Raw), otherNameValue, otherNameValueSecondary)) + + validator.EXPECT().ValidateStrict(gomock.Any()) + resolve, documentMetadata, err := resolver.Resolve(rootDID, &metadata) + + require.NoError(t, err) + assert.NotNil(t, resolve) + require.NoError(t, err) + assert.NotNil(t, documentMetadata) + // Check that the DID url is did#0 + didUrl, err := did.ParseDIDURL(rootDID.String() + "#0") + assert.NotNil(t, resolve.VerificationMethod.FindByID(*didUrl)) + }) t.Run("happy flow with only x5t header", func(t *testing.T) { delete(metadata.JwtProtectedHeaders, X509CertThumbprintS256Header) validator.EXPECT().ValidateStrict(gomock.Any()) @@ -236,7 +279,7 @@ func TestManager_Resolve_San_Generic(t *testing.T) { resolver := NewResolver(validator) metadata := resolver2.ResolveMetadata{} - _, certChain, rootCertificate, _, signingCert, err := BuildCertChain("") + _, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{}) require.NoError(t, err) metadata.JwtProtectedHeaders = make(map[string]interface{}) metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain @@ -316,7 +359,7 @@ func TestManager_Resolve_Subject(t *testing.T) { metadata := resolver2.ResolveMetadata{} otherNameValue := "A_BIG_STRING" - _, certChain, rootCertificate, _, signingCert, err := BuildCertChain(otherNameValue) + _, certChain, rootCertificate, _, signingCert, err := BuildCertChain([]string{otherNameValue}) require.NoError(t, err) metadata.JwtProtectedHeaders = make(map[string]interface{}) metadata.JwtProtectedHeaders[X509CertChainHeader] = certChain diff --git a/vdr/didx509/validation.go b/vdr/didx509/validation.go index f70f1193c..22a3153da 100644 --- a/vdr/didx509/validation.go +++ b/vdr/didx509/validation.go @@ -144,11 +144,11 @@ type validationFunction func(cert *x509.Certificate, key string, value string) e // validatorMap maps PolicyKey to their corresponding validation functions for certificate attributes. var validatorMap = map[PolicyKey]validationFunction{ SanPolicyOtherName: func(cert *x509.Certificate, key string, value string) error { - nameValue, err := findOtherNameValue(cert) + nameValues, err := findOtherNameValues(cert) if err != nil { return err } - if nameValue != value { + if !slices.Contains(nameValues, value) { return fmt.Errorf("the SAN attribute %s does not match the query", key) } return nil diff --git a/vdr/didx509/x509_utils.go b/vdr/didx509/x509_utils.go index c6111e99f..98fe2f817 100644 --- a/vdr/didx509/x509_utils.go +++ b/vdr/didx509/x509_utils.go @@ -91,19 +91,19 @@ var ( OtherNameType = asn1.ObjectIdentifier{2, 5, 5, 5} ) -// findOtherNameValue extracts the value of a specified OtherName type from the certificate -func findOtherNameValue(cert *x509.Certificate) (string, error) { +// findOtherNameValues extracts the value of a specified OtherName types from the certificate +func findOtherNameValues(cert *x509.Certificate) ([]string, error) { for _, extension := range cert.Extensions { if extension.Id.Equal(SubjectAlternativeNameType) { - return findSanValue(extension) + return findSanValues(extension) } } - return "", nil + return make([]string, 0), nil } -// findSanValue extracts the SAN value from a given pkix.Extension, returning the resulting value or an error. -func findSanValue(extension pkix.Extension) (string, error) { - value := "" +// findSanValues extracts the SAN values from a given pkix.Extension, returning the resulting values or an error. +func findSanValues(extension pkix.Extension) ([]string, error) { + var values []string err := forEachSan(extension.Value, func(data []byte) error { var other OtherName _, err := asn1.UnmarshalWithParams(data, &other, "tag:0") @@ -111,17 +111,19 @@ func findSanValue(extension pkix.Extension) (string, error) { return err } if other.TypeID.Equal(OtherNameType) { + var value string _, err = asn1.Unmarshal(other.Value.Bytes, &value) if err != nil { return err } + values = append(values, value) } return nil }) if err != nil { - return "", err + return make([]string, 0), err } - return value, err + return values, err } // forEachSan processes each SAN extension in the certificate diff --git a/vdr/didx509/x509_utils_test.go b/vdr/didx509/x509_utils_test.go index 7c2144217..fc4de62f3 100644 --- a/vdr/didx509/x509_utils_test.go +++ b/vdr/didx509/x509_utils_test.go @@ -32,13 +32,14 @@ import ( "github.com/lestrrat-go/jwx/v2/cert" "math/big" "net" + "slices" "strings" "testing" "time" ) // BuildCertChain generates a certificate chain, including root, intermediate, and signing certificates. -func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain *cert.Chain, rootCertificate *x509.Certificate, signingKey *rsa.PrivateKey, signingCert *x509.Certificate, err error) { +func BuildCertChain(identifiers []string) (chainCerts [4]*x509.Certificate, chain *cert.Chain, rootCertificate *x509.Certificate, signingKey *rsa.PrivateKey, signingCert *x509.Certificate, err error) { chainCerts = [4]*x509.Certificate{} chain = &cert.Chain{} rootKey, rootCert, rootPem, err := buildRootCert() @@ -68,7 +69,7 @@ func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain * return chainCerts, nil, nil, nil, nil, err } - signingKey, signingCert, signingPEM, err := buildSigningCert(identifier, intermediateL2Cert, intermediateL2Key, "32121323") + signingKey, signingCert, signingPEM, err := buildSigningCert(identifiers, intermediateL2Cert, intermediateL2Key, "32121323") if err != nil { return chainCerts, nil, nil, nil, nil, err } @@ -80,12 +81,12 @@ func BuildCertChain(identifier string) (chainCerts [4]*x509.Certificate, chain * return chainCerts, chain, rootCert, signingKey, signingCert, nil } -func buildSigningCert(identifier string, intermediateL2Cert *x509.Certificate, intermediateL2Key *rsa.PrivateKey, serialNumber string) (*rsa.PrivateKey, *x509.Certificate, []byte, error) { +func buildSigningCert(identifiers []string, intermediateL2Cert *x509.Certificate, intermediateL2Key *rsa.PrivateKey, serialNumber string) (*rsa.PrivateKey, *x509.Certificate, []byte, error) { signingKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, nil, err } - signingTmpl, err := SigningCertTemplate(nil, identifier) + signingTmpl, err := SigningCertTemplate(nil, identifiers) if err != nil { return nil, nil, nil, err } @@ -152,7 +153,7 @@ func CertTemplate(serialNumber *big.Int) (*x509.Certificate, error) { } // SigningCertTemplate creates a x509.Certificate template for a signing certificate with an optional serial number. -func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certificate, error) { +func SigningCertTemplate(serialNumber *big.Int, identifiers []string) (*x509.Certificate, error) { // generate a random serial number (a real cert authority would have some logic behind this) if serialNumber == nil { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 8) @@ -179,8 +180,8 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} // Either the ExtraExtensions SubjectAlternativeNameType is set, or the Subject Alternate Name values are set, // both don't mix - if identifier != "" { - err := setSanAlternativeName(&tmpl, identifier) + if len(identifiers) > 0 { + err := setSanAlternativeName(&tmpl, identifiers) if err != nil { return nil, err } @@ -192,27 +193,30 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif return &tmpl, nil } -func setSanAlternativeName(tmpl *x509.Certificate, identifier string) error { - raw, err := toRawValue(identifier, "ia5") - if err != nil { - return err - } - otherName := OtherName{ - TypeID: OtherNameType, - Value: asn1.RawValue{ - Class: 2, - Tag: 0, - IsCompound: true, - Bytes: raw.FullBytes, - }, - } +func setSanAlternativeName(tmpl *x509.Certificate, identifiers []string) error { + var list []asn1.RawValue - raw, err = toRawValue(otherName, "tag:0") - if err != nil { - return err + for _, identifier := range identifiers { + raw, err := toRawValue(identifier, "ia5") + if err != nil { + return err + } + otherName := OtherName{ + TypeID: OtherNameType, + Value: asn1.RawValue{ + Class: 2, + Tag: 0, + IsCompound: true, + Bytes: raw.FullBytes, + }, + } + + raw, err = toRawValue(otherName, "tag:0") + if err != nil { + return err + } + list = append(list, *raw) } - var list []asn1.RawValue - list = append(list, *raw) marshal, err := asn1.Marshal(list) if err != nil { return err @@ -261,7 +265,7 @@ func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv func TestFindOtherNameValue(t *testing.T) { t.Parallel() key, certificate, _, err := buildRootCert() - _, signingCert, _, err := buildSigningCert("123", certificate, key, "4567") + _, signingCert, _, err := buildSigningCert([]string{"123", "321"}, certificate, key, "4567") if err != nil { t.Fatalf("failed to build root certificate: %v", err) } @@ -279,22 +283,28 @@ func TestFindOtherNameValue(t *testing.T) { wantErr: false, }, { - name: "with extensions", + name: "with extensions first", cert: signingCert, want: "123", wantErr: false, }, + { + name: "with extensions second", + cert: signingCert, + want: "321", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotName, err := findOtherNameValue(tt.cert) + gotName, err := findOtherNameValues(tt.cert) if (err != nil) != tt.wantErr { - t.Errorf("findOtherNameValue() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("findOtherNameValues() error = %v, wantErr %v", err, tt.wantErr) return } - if gotName != tt.want { - t.Errorf("findOtherNameValue() = %v, want %v", gotName, tt.want) + if tt.want != "" && !slices.Contains(gotName, tt.want) { + t.Errorf("findOtherNameValues() = %v, want %v", gotName, tt.want) } }) } @@ -308,7 +318,7 @@ func TestFindCertificateByHash(t *testing.T) { } return base64.RawURLEncoding.EncodeToString(h) } - chainCerts, _, _, _, _, err := BuildCertChain("123") + chainCerts, _, _, _, _, err := BuildCertChain([]string{"123"}) if err != nil { t.Error(err) } @@ -409,7 +419,7 @@ func TestParseChain(t *testing.T) { } return &chain } - certs, chain, _, _, _, _ := BuildCertChain("123") + certs, chain, _, _, _, _ := BuildCertChain([]string{"123"}) invalidPEM := `-----BEGIN CERTIFICATE----- Y29ycnVwdCBjZXJ0aWZpY2F0ZQo= @@ -704,7 +714,7 @@ func TestFindSanValue(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - val, foundErr := findSanValue(pkix.Extension{ + val, foundErr := findSanValues(pkix.Extension{ Value: tt.rest, }) if foundErr != nil { @@ -719,15 +729,15 @@ func TestFindSanValue(t *testing.T) { t.Errorf("forEachSan() error = %v", foundErr) } if foundErr.Error() != tt.wantError.Error() { - t.Errorf("findSanValue() error = %v, want: %v", foundErr, tt.wantError) + t.Errorf("findSanValues() error = %v, want: %v", foundErr, tt.wantError) return } } } } - if val != tt.expectedValue { - t.Errorf("findSanValue() = %v, want: %v", val, tt.expectedValue) + if tt.expectedValue != "" && !slices.Contains(val, tt.expectedValue) { + t.Errorf("findSanValues() = %v, want: %v", val, tt.expectedValue) } })