diff --git a/libs/go/sia/Makefile b/libs/go/sia/Makefile index 0b484298254..1b692cc5283 100644 --- a/libs/go/sia/Makefile +++ b/libs/go/sia/Makefile @@ -12,7 +12,7 @@ ifneq ($(patsubst %$(SIA_DIR),,$(lastword $(ATHENZ_DIR))),) ATHENZ_DIR = $(PWD)/$(SIA_DIR) endif -SUBDIRS = access/config access/tokens agent aws/agent aws/attestation aws/doc aws/lambda aws/meta \ +SUBDIRS = access/config access/tokens agent aws/attestation aws/doc aws/lambda aws/meta \ aws/options aws/stssession file futil gcp/attestation gcp/meta gcp/functions \ host/hostdoc host/ip host/provider host/signature host/utils logutil options pki/cert \ sds ssh/hostcert ssh/hostkey util verify diff --git a/libs/go/sia/agent/agent.go b/libs/go/sia/agent/agent.go index d5561d9ebf5..a5eb842b0da 100644 --- a/libs/go/sia/agent/agent.go +++ b/libs/go/sia/agent/agent.go @@ -36,6 +36,7 @@ import ( "github.com/AthenZ/athenz/libs/go/sia/access/config" "github.com/AthenZ/athenz/libs/go/sia/access/tokens" sc "github.com/AthenZ/athenz/libs/go/sia/config" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/options" "github.com/AthenZ/athenz/libs/go/sia/sds" "github.com/AthenZ/athenz/libs/go/sia/ssh/hostkey" @@ -278,7 +279,7 @@ func registerSvc(svc sc.Service, ztsUrl string, opts *sc.Options) error { if err != nil { return err } - attestData, err := opts.Provider.CloudAttestationData(opts.MetaEndPoint, svc.Name, ztsUrl) + attestData, err := opts.Provider.CloudAttestationData(setUpAttestationRequest(opts, svc.Name, ztsUrl)) if err != nil { log.Printf("Failed to get attestation data to prove the identity, err:%v\n", err) return err @@ -350,6 +351,21 @@ func registerSvc(svc sc.Service, ztsUrl string, opts *sc.Options) error { return nil } +func setUpAttestationRequest(opts *sc.Options, service, ztsUrl string) *provider.AttestationRequest { + return &provider.AttestationRequest{ + MetaEndPoint: opts.MetaEndPoint, + Domain: opts.Domain, + Service: service, + ZTSUrl: ztsUrl, + Account: opts.Account, + Region: opts.Region, + OmitDomain: opts.OmitDomain, + UseRegionalSTS: opts.UseRegionalSTS, + EC2Document: opts.EC2Document, + EC2Signature: opts.EC2Signature, + } +} + func refreshSvc(svc sc.Service, ztsUrl string, opts *sc.Options) error { keyFile := util.GetSvcKeyFileName(opts.KeyDir, svc.KeyFilename, opts.Domain, svc.Name) @@ -399,7 +415,7 @@ func refreshSvc(svc sc.Service, ztsUrl string, opts *sc.Options) error { return err } - attestData, err := opts.Provider.CloudAttestationData(opts.MetaEndPoint, svc.Name, ztsUrl) + attestData, err := opts.Provider.CloudAttestationData(setUpAttestationRequest(opts, svc.Name, ztsUrl)) if err != nil { log.Printf("Failed to get attestation data to prove the identity, err:%v\n", err) return err @@ -637,10 +653,6 @@ func runAgentCommand(siaCmd, ztsUrl string, opts *sc.Options) { //server and role certs are valid for 30 days by default rotationInterval := time.Duration(opts.RefreshInterval) * time.Minute - //data, err := opts.Provider.CloudAttestationData(opts) - //if err != nil { - // log.Fatalf("Cannot determine identity to run as, err:%v\n", err) - //} svcs := options.GetSvcNames(opts.Services) tokenOpts, err := tokenOptions(opts, ztsUrl) diff --git a/libs/go/sia/agent/agent_test.go b/libs/go/sia/agent/agent_test.go index a54b85ee4d3..4578322231c 100644 --- a/libs/go/sia/agent/agent_test.go +++ b/libs/go/sia/agent/agent_test.go @@ -28,16 +28,16 @@ import ( "testing" "time" - "k8s.io/utils/strings/slices" - "github.com/AthenZ/athenz/libs/go/sia/access/config" "github.com/AthenZ/athenz/libs/go/sia/agent/devel/ztsmock" sc "github.com/AthenZ/athenz/libs/go/sia/config" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" "github.com/AthenZ/athenz/libs/go/sia/ssh/hostkey" "github.com/AthenZ/athenz/libs/go/sia/util" "github.com/stretchr/testify/assert" + "k8s.io/utils/strings/slices" ) func setup() { @@ -105,7 +105,7 @@ func (tp TestProvider) GetSuffixes() []string { return []string{} } -func (tp TestProvider) CloudAttestationData(string, string, string) (string, error) { +func (tp TestProvider) CloudAttestationData(*provider.AttestationRequest) (string, error) { return "abc", nil } @@ -211,7 +211,7 @@ func TestRegisterInstance(test *testing.T) { KeyDir: siaDir, CertDir: siaDir, AthenzCACertFile: caCertFile, - ZTSAWSDomains: []string{"zts-aws-cloud"}, + ZTSCloudDomains: []string{"zts-aws-cloud"}, Region: "us-west-2", InstanceId: "pod-1234", Provider: tp, @@ -288,7 +288,7 @@ func refreshServiceCertSetup(test *testing.T) (*sc.Options, string) { CertDir: siaDir, AthenzCACertFile: caCertFile, Provider: tp, - ZTSAWSDomains: []string{"zts-aws-cloud"}, + ZTSCloudDomains: []string{"zts-aws-cloud"}, Region: "us-west-2", InstanceId: "pod-1234", } @@ -366,7 +366,7 @@ func TestRoleCertificateRequest(test *testing.T) { KeyDir: siaDir, CertDir: siaDir, AthenzCACertFile: caCertFile, - ZTSAWSDomains: []string{"zts-aws-cloud"}, + ZTSCloudDomains: []string{"zts-aws-cloud"}, Provider: tp, } @@ -621,7 +621,7 @@ func TestGenerateSshRequest(test *testing.T) { // ssh enabled with primary service and key type is rsa - null cert request but valid csr opts.SshPubKeyFile = "devel/data/cert.pem" opts.Domain = "athenz" - opts.ZTSAWSDomains = []string{"athenz.io"} + opts.ZTSCloudDomains = []string{"athenz.io"} opts.SshHostKeyType = hostkey.Rsa sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io") assert.Nil(test, sshReq) @@ -631,9 +631,10 @@ func TestGenerateSshRequest(test *testing.T) { opts.SshHostKeyType = hostkey.Ecdsa sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io") assert.NotNil(test, sshReq) - assert.Equal(test, 3, len(sshReq.CertRequestData.Principals)) + assert.Equal(test, 4, len(sshReq.CertRequestData.Principals)) assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-vm")) assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-instance-id")) + assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "api.athenz.athenz.io")) assert.Empty(test, sshCsr) assert.Nil(test, err) // ssh enabled with primary service and key type is ecdsa - empty csr but not-nil cert request, opts defines sshPrincipals @@ -641,11 +642,12 @@ func TestGenerateSshRequest(test *testing.T) { opts.SshPrincipals = "cname.athenz.io" sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io") assert.NotNil(test, sshReq) - assert.Equal(test, 4, len(sshReq.CertRequestData.Principals)) + assert.Equal(test, 5, len(sshReq.CertRequestData.Principals)) assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "hostname.athenz.io")) assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "cname.athenz.io")) assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-vm")) assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "my-instance-id")) + assert.True(test, slices.Contains(sshReq.CertRequestData.Principals, "api.athenz.athenz.io")) assert.Empty(test, sshCsr) assert.Nil(test, err) } diff --git a/libs/go/sia/aws/agent/agent.go b/libs/go/sia/aws/agent/agent.go deleted file mode 100644 index 2ab5bb7a5f7..00000000000 --- a/libs/go/sia/aws/agent/agent.go +++ /dev/null @@ -1,964 +0,0 @@ -// -// Copyright The Athenz Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package agent - -import ( - "bufio" - "crypto/rsa" - "crypto/x509" - "encoding/json" - "encoding/pem" - "fmt" - "log" - "math" - "os" - "os/exec" - "os/signal" - "strings" - "syscall" - "time" - - "github.com/AthenZ/athenz/clients/go/zts" - "github.com/AthenZ/athenz/libs/go/athenzutils" - "github.com/AthenZ/athenz/libs/go/sia/access/config" - "github.com/AthenZ/athenz/libs/go/sia/access/tokens" - "github.com/AthenZ/athenz/libs/go/sia/aws/attestation" - "github.com/AthenZ/athenz/libs/go/sia/aws/options" - sc "github.com/AthenZ/athenz/libs/go/sia/config" - "github.com/AthenZ/athenz/libs/go/sia/sds" - "github.com/AthenZ/athenz/libs/go/sia/ssh/hostkey" - "github.com/AthenZ/athenz/libs/go/sia/util" - "github.com/ardielle/ardielle-go/rdl" - "github.com/cenkalti/backoff" -) - -const siaMainDir = "/var/lib/sia" - -func readCertificate(certFile string) (*x509.Certificate, error) { - data, err := os.ReadFile(certFile) - if err != nil { - return nil, err - } - var block *pem.Block - block, _ = pem.Decode(data) - if block == nil { - return nil, nil - } - return x509.ParseCertificate(block.Bytes) -} - -func GetPrevRoleCertDates(certFile string) (*rdl.Timestamp, *rdl.Timestamp, error) { - prevRolCert, err := readCertificate(certFile) - if err != nil { - return nil, nil, err - } - - notBefore := &rdl.Timestamp{ - Time: prevRolCert.NotBefore, - } - - notAfter := &rdl.Timestamp{ - Time: prevRolCert.NotAfter, - } - - log.Printf("Existing role cert %s, not before: %s, not after: %s\n", certFile, notBefore.String(), notAfter.String()) - return notBefore, notAfter, nil -} - -func RoleKey(rotateKey bool, roleKey, svcKey string) (*rsa.PrivateKey, error) { - if rotateKey { - return util.GenerateKeyPair(2048) - } else if roleKey != "" && util.FileExists(roleKey) { - return util.PrivateKeyFromFile(roleKey) - } else { - return util.PrivateKeyFromFile(svcKey) - } -} - -func GetRoleCertificates(ztsUrl string, opts *sc.Options) (int, []string) { - - //initialize our return state to success - failures := make([]string, 0) - - for _, role := range opts.Roles { - var roleRequest = new(zts.RoleCertificateRequest) - - svcKeyFile := util.GetSvcKeyFileName(opts.KeyDir, role.SvcKeyFilename, opts.Domain, role.Service) - svcCertFile := util.GetSvcCertFileName(opts.CertDir, role.SvcCertFilename, opts.Domain, role.Service) - - client, err := util.ZtsClient(ztsUrl, opts.ZTSServerName, svcKeyFile, svcCertFile, opts.ZTSCACertFile) - if err != nil { - log.Printf("unable to initialize ZTS Client with url %s for role %s, err: %v\n", ztsUrl, role.Name, err) - failures = append(failures, role.Name) - continue - } - client.AddCredentials("User-Agent", opts.Version) - - var key *rsa.PrivateKey - if opts.GenerateRoleKey { - key, err = RoleKey(opts.RotateKey, role.RoleKeyFilename, svcKeyFile) - } else { - key, err = util.PrivateKeyFromFile(svcKeyFile) - } - if err != nil { - log.Printf("unable to read private key role %s, err: %v\n", role.Name, err) - failures = append(failures, role.Name) - continue - } - - emailDomain := "" - if opts.RolePrincipalEmail { - emailDomain = opts.ZTSAWSDomains[0] - } - roleCertReqOptions := &util.RoleCertReqOptions{ - Country: opts.CertCountryName, - OrgName: opts.CertOrgName, - Domain: opts.Domain, - Service: role.Service, - RoleName: role.Name, - InstanceId: opts.InstanceId, - Provider: opts.Provider.GetName(), - EmailDomain: emailDomain, - SpiffeTrustDomain: opts.SpiffeTrustDomain, - } - csr, err := util.GenerateRoleCertCSR(key, roleCertReqOptions) - if err != nil { - log.Printf("unable to generate CSR for %s, err: %v\n", role.Name, err) - failures = append(failures, role.Name) - continue - } - roleRequest.Csr = csr - if role.ExpiryTime > 0 { - roleRequest.ExpiryTime = int64(role.ExpiryTime) - } - - notBefore, notAfter, _ := GetPrevRoleCertDates(role.RoleCertFilename) - roleRequest.PrevCertNotBefore = notBefore - roleRequest.PrevCertNotAfter = notAfter - if notBefore != nil && notAfter != nil { - log.Printf("Previous Role Cert Not Before date: %s, Not After date: %s\n", notBefore, notAfter) - } - - //"rolename": "athenz.fp:role.readers" - //from the rolename, domain is athenz.fp - //role is readers - roleCert, err := client.PostRoleCertificateRequestExt(roleRequest) - if err != nil { - log.Printf("PostRoleCertificateRequest failed for %s, err: %v\n", role.Name, err) - failures = append(failures, role.Name) - continue - } - - roleKeyBytes := util.PrivatePem(key) - err = util.SaveRoleCertKey([]byte(roleKeyBytes), []byte(roleCert.X509Certificate), role.RoleKeyFilename, role.RoleCertFilename, svcKeyFile, role.Name, role.Uid, role.Gid, role.FileMode, opts.GenerateRoleKey, opts.RotateKey, opts.BackupDir, opts.FileDirectUpdate) - if err != nil { - failures = append(failures, role.Name) - continue - } - } - log.Printf("SIA processed %d (failures %d) role certificate requests\n", len(opts.Roles), len(failures)) - return len(opts.Roles), failures -} - -func RegisterInstance(data []*attestation.AttestationData, ztsUrl string, opts *sc.Options, docExpiryCheck bool) error { - - //special handling for EC2 instances - //before we process our register event we need to check to - //see if our timestamp in our document is less than 30 mins - //ago otherwise ZTS server will reject the request and there - //is no point of processing the request - if docExpiryCheck && shouldSkipRegister(opts) { - return fmt.Errorf("identity document has expired (30 min timeout). ZTS will not register this instance. Please relaunch or stop and start your instance to refesh its identity document") - } - - for i, svc := range opts.Services { - err := registerSvc(svc, data[i], ztsUrl, opts) - if err != nil { - return fmt.Errorf("unable to register identity for svc: %q, error: %v", svc.Name, err) - } - } - return nil -} - -func RefreshInstance(data []*attestation.AttestationData, ztsUrl string, opts *sc.Options) error { - for i, svc := range opts.Services { - err := refreshSvc(svc, data[i], ztsUrl, opts) - if err != nil { - return fmt.Errorf("unable to refresh identity for svc: %q, error: %v", svc.Name, err) - } - } - return nil -} - -func getServiceHostname(opts *sc.Options, svc sc.Service, fqdn bool) string { - if !opts.SanDnsHostname { - return "" - } - hostname := opts.Provider.GetHostname(fqdn) - if hostname == "" { - log.Println("No hostname configured for the instance") - return "" - } - //if the hostname contains multiple components then we'll - //return our hostname as is - if strings.Contains(hostname, ".") { - return hostname - } - //otherwise, we'll generate one based on the format - //... only if the - //suffix is properly configured since we might be having - //multiple suffix values - if opts.HostnameSuffix == "" { - // if our initial request was without fqdn then we're - // going to retry with the fqdn otherwise we'll just - // return an empty string - if fqdn { - log.Printf("No hostname suffix configured for the instance: %s\n", hostname) - return "" - } else { - return getServiceHostname(opts, svc, true) - } - } - - hyphenDomain := strings.Replace(opts.Domain, ".", "-", -1) - return fmt.Sprintf("%s.%s.%s.%s", hostname, svc.Name, hyphenDomain, opts.HostnameSuffix) -} - -func registerSvc(svc sc.Service, data *attestation.AttestationData, ztsUrl string, opts *sc.Options) error { - - key, err := util.GenerateKeyPair(2048) - if err != nil { - return err - } - - //if ssh support is enabled then we need to generate the csr - //it is also generated for the primary service only - hostname := getServiceHostname(opts, svc, false) - sshCertRequest, sshCsr, err := generateSshRequest(opts, svc.Name, hostname) - if err != nil { - return err - } - - //if the user hasn't configured to include the san dns hostname - //then we're going to reset the hostname value to an empty string - if !opts.SanDnsHostname { - hostname = "" - } - svcCertReqOptions := &util.SvcCertReqOptions{ - Country: opts.CertCountryName, - OrgName: opts.CertOrgName, - Domain: opts.Domain, - Service: svc.Name, - CommonName: data.CommonName, - InstanceId: opts.InstanceId, - Provider: opts.Provider.GetName(), - Hostname: hostname, - SpiffeTrustDomain: opts.SpiffeTrustDomain, - SpiffeNamespace: opts.SpiffeNamespace, - AddlSanDNSEntries: opts.AddlSanDNSEntries, - ZtsDomains: opts.ZTSAWSDomains, - WildCardDnsName: opts.SanDnsWildcard, - InstanceIdSanDNS: opts.InstanceIdSanDNS, - } - csr, err := util.GenerateSvcCertCSR(key, svcCertReqOptions) - if err != nil { - return err - } - attestData, err := json.Marshal(data) - if err != nil { - return err - } - - athenzJwk := true - athenzJwkModified := util.GetAthenzJwkConfModTime(siaMainDir) - - info := &zts.InstanceRegisterInformation{ - Provider: zts.ServiceName(opts.Provider.GetName()), - Domain: zts.DomainName(opts.Domain), - Service: zts.SimpleName(svc.Name), - Csr: csr, - Ssh: sshCsr, - SshCertRequest: sshCertRequest, - AttestationData: string(attestData), - AthenzJWK: &athenzJwk, - AthenzJWKModified: &athenzJwkModified, - Hostname: zts.DomainName(hostname), - Namespace: zts.SimpleName(opts.SpiffeNamespace), - } - if svc.ExpiryTime > 0 && svc.ExpiryTime <= math.MaxInt32 { - expiryTime := int32(svc.ExpiryTime) - info.ExpiryTime = &expiryTime - } - - client, err := util.ZtsClient(ztsUrl, opts.ZTSServerName, "", "", opts.ZTSCACertFile) - if err != nil { - return err - } - client.AddCredentials("User-Agent", opts.Version) - - ident, _, err := client.PostInstanceRegisterInformation(info) - if err != nil { - log.Printf("Unable to do PostInstanceRegisterInformation, err: %v\n", err) - return err - } - svcKeyFile := util.GetSvcKeyFileName(opts.KeyDir, svc.KeyFilename, opts.Domain, svc.Name) - err = util.UpdateFile(svcKeyFile, []byte(util.PrivatePem(key)), svc.Uid, svc.Gid, 0440, opts.FileDirectUpdate, true) - if err != nil { - return err - } - svcCertFile := util.GetSvcCertFileName(opts.CertDir, svc.CertFilename, opts.Domain, svc.Name) - err = util.UpdateFile(svcCertFile, []byte(ident.X509Certificate), svc.Uid, svc.Gid, 0444, opts.FileDirectUpdate, true) - if err != nil { - return err - } - - if opts.Services[0].Name == svc.Name { - err = util.UpdateFile(opts.AthenzCACertFile, []byte(ident.X509CertificateSigner), svc.Uid, svc.Gid, 0444, opts.FileDirectUpdate, true) - if err != nil { - return err - } - } - //we're not going to count ssh updates as fatal since the primary - //task for sia to get service identity certs but we'll log the failure - if ident.SshCertificate != "" { - err = updateSSH(opts.SshCertFile, opts.SshConfigFile, ident.SshCertificate, opts.FileDirectUpdate) - if err != nil { - log.Printf("Unable to update ssh certificate, err: %v\n", err) - } - } - - if ident.AthenzJWK != nil { - err = util.WriteAthenzJWKFile(ident.AthenzJWK, siaMainDir, svc.Uid, svc.Gid) - if err != nil { - return err - } - } - return nil -} - -func refreshSvc(svc sc.Service, data *attestation.AttestationData, ztsUrl string, opts *sc.Options) error { - - keyFile := util.GetSvcKeyFileName(opts.KeyDir, svc.KeyFilename, opts.Domain, svc.Name) - certFile := util.GetSvcCertFileName(opts.CertDir, svc.CertFilename, opts.Domain, svc.Name) - - key, err := util.PrivateKey(keyFile, opts.RotateKey) - if err != nil { - log.Printf("Unable to read private key from %s, err: %v\n", keyFile, err) - return err - } - - //if ssh support is enabled then we need to generate the csr - //it is also generated for the primary service only - hostname := getServiceHostname(opts, svc, false) - sshCertRequest, sshCsr, err := generateSshRequest(opts, svc.Name, hostname) - if err != nil { - return err - } - - //if the user hasn't configured to include the san dns hostname - //then we're going to reset the hostname value to an empty string - if !opts.SanDnsHostname { - hostname = "" - } - svcCertReqOptions := &util.SvcCertReqOptions{ - Country: opts.CertCountryName, - OrgName: opts.CertOrgName, - Domain: opts.Domain, - Service: svc.Name, - CommonName: data.CommonName, - InstanceId: opts.InstanceId, - Provider: opts.Provider.GetName(), - Hostname: hostname, - SpiffeTrustDomain: opts.SpiffeTrustDomain, - SpiffeNamespace: opts.SpiffeNamespace, - AddlSanDNSEntries: opts.AddlSanDNSEntries, - ZtsDomains: opts.ZTSAWSDomains, - WildCardDnsName: opts.SanDnsWildcard, - InstanceIdSanDNS: opts.InstanceIdSanDNS, - } - csr, err := util.GenerateSvcCertCSR(key, svcCertReqOptions) - if err != nil { - log.Printf("Unable to generate CSR for %s, err: %v\n", opts.Name, err) - return err - } - attestData, err := json.Marshal(data) - if err != nil { - return err - } - - athenzJwk := true - athenzJwkModified := util.GetAthenzJwkConfModTime(siaMainDir) - - info := &zts.InstanceRefreshInformation{ - AttestationData: string(attestData), - Csr: csr, - Ssh: sshCsr, - SshCertRequest: sshCertRequest, - AthenzJWK: &athenzJwk, - AthenzJWKModified: &athenzJwkModified, - Hostname: zts.DomainName(hostname), - Namespace: zts.SimpleName(opts.SpiffeNamespace), - } - if svc.ExpiryTime > 0 && svc.ExpiryTime <= math.MaxInt32 { - expiryTime := int32(svc.ExpiryTime) - info.ExpiryTime = &expiryTime - } - - client, err := util.ZtsClient(ztsUrl, opts.ZTSServerName, keyFile, certFile, opts.ZTSCACertFile) - if err != nil { - log.Printf("Unable to get ZTS Client for %s, err: %v\n", ztsUrl, err) - return err - } - client.AddCredentials("User-Agent", opts.Version) - - ident, err := client.PostInstanceRefreshInformation(zts.ServiceName(opts.Provider.GetName()), zts.DomainName(opts.Domain), zts.SimpleName(svc.Name), zts.PathElement(opts.InstanceId), info) - if err != nil { - log.Printf("Unable to refresh instance service certificate for %s, err: %v\n", opts.Name, err) - return err - } - - svcKeyBytes := util.PrivatePem(key) - svcCertBytes := []byte(ident.X509Certificate) - serviceName := fmt.Sprintf("%s.%s", opts.Domain, svc.Name) - err = util.SaveServiceCertKey([]byte(svcKeyBytes), svcCertBytes, keyFile, certFile, serviceName, svc.Uid, svc.Gid, svc.FileMode, opts.RotateKey, opts.BackupDir, opts.FileDirectUpdate) - if err != nil { - return err - } - - if opts.Services[0].Name == svc.Name { - err = util.UpdateFile(opts.AthenzCACertFile, []byte(ident.X509CertificateSigner), svc.Uid, svc.Gid, 0444, opts.FileDirectUpdate, true) - if err != nil { - return err - } - } - //we're not going to count ssh updates as fatal since the primary - //task for sia to get service identity certs but we'll log the failure - if ident.SshCertificate != "" { - err = updateSSH(opts.SshCertFile, opts.SshConfigFile, ident.SshCertificate, opts.FileDirectUpdate) - if err != nil { - log.Printf("Unable to update ssh certificate, err: %v\n", err) - } - } - - if ident.AthenzJWK != nil { - err = util.WriteAthenzJWKFile(ident.AthenzJWK, siaMainDir, svc.Uid, svc.Gid) - if err != nil { - return err - } - } - return nil -} - -func generateSshRequest(opts *sc.Options, primaryServiceName, hostname string) (*zts.SSHCertRequest, string, error) { - var err error - var sshCsr string - var sshCertRequest *zts.SSHCertRequest - if opts.Ssh && opts.Services[0].Name == primaryServiceName { - if opts.SshHostKeyType == hostkey.Rsa { - sshCsr, err = util.GenerateSSHHostCSR(opts.SshPubKeyFile, opts.Domain, primaryServiceName, opts.PrivateIp, opts.ZTSAWSDomains) - } else { - sshPrincipals := opts.SshPrincipals - // additional ssh host principals are added on best effort basis, hence error below is ignored. - additionalSshHostPrincipals, _ := opts.Provider.GetAdditionalSshHostPrincipals(opts.MetaEndPoint) - if additionalSshHostPrincipals != "" { - if sshPrincipals != "" { - sshPrincipals = sshPrincipals + "," + additionalSshHostPrincipals - } else { - sshPrincipals = additionalSshHostPrincipals - } - } - sshCertRequest, err = util.GenerateSSHHostRequest(opts.SshPubKeyFile, opts.Domain, primaryServiceName, hostname, opts.PrivateIp, opts.InstanceId, sshPrincipals, opts.ZTSAWSDomains) - } - } - return sshCertRequest, sshCsr, err -} - -func restartSshdService() error { - return exec.Command(util.GetUtilPath("systemctl"), "restart", "sshd").Run() -} - -func updateSSH(sshCertFile, sshConfigFile, hostCert string, fileDirectUpdate bool) error { - - //write the host cert file - err := util.UpdateFile(sshCertFile, []byte(hostCert), 0, 0, 0644, fileDirectUpdate, true) - if err != nil { - return err - } - - //Now update the config file, if needed. The format of the line we're going - //to insert is HostCertificate . so we'll see if the line exists - //or not and if not we'll insert one at the end of the file - if sshConfigFile != "" { - configPresent, err := hostCertificateLinePresent(sshConfigFile, sshCertFile) - if err != nil { - log.Printf("unable to check host certificate line for %s - error %v\n", sshConfigFile, err) - return err - } - if configPresent { - return nil - } - //update the sshconfig file to include HostCertificate line - err = updateSSHConfigFile(sshConfigFile, sshCertFile) - if err != nil { - return err - } - //and restart sshd to notice the changes. - return restartSshdService() - } - return nil -} - -func updateSSHConfigFile(sshConfigFile, sshCertFile string) error { - //update the sshd config file to include HostCertificate line - file, err := os.OpenFile(sshConfigFile, os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - return err - } - defer file.Close() - certLine := fmt.Sprintf("\nHostCertificate %s\n", sshCertFile) - _, err = file.Write([]byte(certLine)) - if err != nil { - return err - } - return nil -} - -func hostCertificateLinePresent(sshConfigFile, sshCertFile string) (bool, error) { - - certLine := fmt.Sprintf("HostCertificate %s", sshCertFile) - file, err := os.Open(sshConfigFile) - if err != nil { - return false, err - } - defer file.Close() - scanner := bufio.NewScanner(file) - scanner.Split(bufio.ScanLines) - for scanner.Scan() { - line := strings.Trim(scanner.Text(), " \t") - if strings.HasPrefix(line, certLine) { - log.Printf("ssh configuration file already includes expected line: %s\n", line) - return true, nil - } - } - return false, nil -} - -func SetupAgent(opts *sc.Options, siaAgentDir, siaLinkDir string) { - - //first, let's determine if we need to drop our privileges - //since it requires us to create the directories with the - //specified ownership - runUid, runGid := options.GetRunsAsUidGid(opts) - - //make sure all component directories exist and have required ownership - err := util.SetupSIADir(siaAgentDir, runUid, runGid) - if err != nil { - log.Printf("Unable to setup SIA Agent directory '%s': %v\n", siaAgentDir, err) - } - //if we have a link directory specified then we'll create that as well - if siaLinkDir != "" && !util.FileExists(siaLinkDir) { - err = os.Symlink(siaAgentDir, siaLinkDir) - if err != nil { - log.Printf("Unable to symlink SIA directory '%s': %v\n", siaLinkDir, err) - } - } - if siaAgentDir != siaMainDir { - err = util.SetupSIADir(siaMainDir, runUid, runGid) - if err != nil { - log.Printf("Unable to setup SIA Main directory '%s': %v\n", siaMainDir, err) - } - } - err = util.SetupSIADir(opts.KeyDir, runUid, runGid) - if err != nil { - log.Printf("Unable to setup SIA Key directory '%s': %v\n", opts.KeyDir, err) - } - err = util.SetupSIADir(opts.CertDir, runUid, runGid) - if err != nil { - log.Printf("Unable to setup SIA Cert directory '%s': %v\n", opts.CertDir, err) - } - err = util.SetupSIADir(opts.TokenDir, runUid, runGid) - if err != nil { - log.Printf("Unable to setup SIA Token directory '%s': %v\n", opts.TokenDir, err) - } - err = util.SetupSIADir(opts.BackupDir, runUid, runGid) - if err != nil { - log.Printf("Unable to setup SIA Backup directory '%s': %v\n", opts.BackupDir, err) - } - - //check to see if we need to drop our privileges and - //run as the specific group id - if runGid != -1 { - if err := util.SyscallSetGid(runGid); err != nil { - log.Printf("unable to drop privileges to group %d, error: %v\n", runGid, err) - } - } - // same check for the user id - if runUid != -1 { - if err := util.SyscallSetUid(runUid); err != nil { - log.Printf("unable to drop privileges to user %d, error: %v\n", runUid, err) - } - } -} - -func RunAgent(siaCmds, ztsUrl string, opts *sc.Options) { - log.Printf("sia command line arguments specified: '%s'\n", siaCmds) - cmds := strings.Split(siaCmds, ",") - for _, cmd := range cmds { - runAgentCommand(cmd, ztsUrl, opts) - } -} - -func runAgentCommand(siaCmd, ztsUrl string, opts *sc.Options) { - - //make sure the meta endpoint is configured by the caller - if opts.MetaEndPoint == "" { - log.Fatalf("meta endpoint not configured") - } - - //the default value is to rotate once every day since our - //server and role certs are valid for 30 days by default - rotationInterval := time.Duration(opts.RefreshInterval) * time.Minute - - data, err := attestation.GetAttestationData(opts) - if err != nil { - log.Fatalf("Cannot determine identity to run as, err:%v\n", err) - } - svcs := options.GetSvcNames(opts.Services) - - tokenOpts, err := tokenOptions(opts, ztsUrl) - if err != nil { - log.Printf(err.Error()) - } - cmd, skipErrors := util.ParseSiaCmd(siaCmd) - switch cmd { - case "rolecert": - count, failures := GetRoleCertificates(ztsUrl, opts) - if len(failures) != 0 { - util.ExecuteScript(opts.RunAfterCertsErrParts, strings.Join(failures, ","), false) - if !skipErrors { - log.Fatalf("unable to fetch %d out of %d requested role certificates\n", len(failures), count) - } - } - if count != 0 { - util.ExecuteScript(opts.RunAfterCertsOkParts, "", opts.RunAfterFailExit) - } - util.TouchDoneFile(siaMainDir, "rolecert") - case "token": - if tokenOpts != nil { - err := fetchAccessToken(tokenOpts) - if err != nil { - util.ExecuteScript(opts.RunAfterTokensErrParts, err.Error(), false) - if !skipErrors { - log.Fatalf("Unable to fetch access tokens, err: %v\n", err) - } - } - util.ExecuteScript(opts.RunAfterTokensOkParts, "", opts.RunAfterFailExit) - } else { - log.Print("unable to fetch access tokens, invalid or missing configuration") - } - util.TouchDoneFile(siaMainDir, "token") - case "post", "register": - err := RegisterInstance(data, ztsUrl, opts, false) - if err != nil { - log.Fatalf("Unable to register identity, err: %v\n", err) - } - util.ExecuteScript(opts.RunAfterCertsOkParts, "", opts.RunAfterFailExit) - util.TouchDoneFile(siaMainDir, "register") - log.Printf("identity registered for services: %s\n", svcs) - case "rotate", "refresh": - err = RefreshInstance(data, ztsUrl, opts) - if err != nil { - log.Fatalf("Refresh identity failed, err: %v\n", err) - } - util.ExecuteScript(opts.RunAfterCertsOkParts, "", opts.RunAfterFailExit) - util.TouchDoneFile(siaMainDir, "refresh") - log.Printf("Identity successfully refreshed for services: %s\n", svcs) - case "init": - err := RegisterInstance(data, ztsUrl, opts, false) - if err != nil { - log.Fatalf("Unable to register identity, err: %v\n", err) - } - log.Printf("identity registered for services: %s\n", svcs) - count, failures := GetRoleCertificates(ztsUrl, opts) - if len(failures) != 0 { - util.ExecuteScript(opts.RunAfterCertsErrParts, strings.Join(failures, ","), false) - if !skipErrors { - log.Fatalf("unable to fetch %d out of %d requested role certificates\n", len(failures), count) - } - } - util.ExecuteScript(opts.RunAfterCertsOkParts, "", opts.RunAfterFailExit) - if tokenOpts != nil { - err := fetchAccessToken(tokenOpts) - if err != nil { - util.ExecuteScript(opts.RunAfterTokensErrParts, err.Error(), false) - if !skipErrors { - log.Fatalf("Unable to fetch access tokens, err: %v\n", err) - } - } - util.ExecuteScript(opts.RunAfterTokensOkParts, "", opts.RunAfterFailExit) - } - util.TouchDoneFile(siaMainDir, "init") - default: - // we're going to iterate through our configured services. - // if the service key and certificate files exist then we're - // going to refresh the identity, otherwise we're going to - // register it. before registration, we'll verify that we - // haven't passed our 30-min server enforced timeout since - // there is no point to contact ZTS if it's going to reject it - // for any refresh operations, we're going to skip any failures - // since the existing file on disk is still valid, and we can - // refresh during the next daily run. - initialSetup := true - for i, svc := range opts.Services { - if serviceAlreadyRegistered(opts, svc) { - err = refreshSvc(svc, data[i], ztsUrl, opts) - if err != nil { - log.Printf("unable to refresh identity for svc: %q, error: %v", svc.Name, err) - } - } else { - if shouldSkipRegister(opts) { - log.Fatalf("identity document has expired (30 min timeout). ZTS will not register this instance. Please relaunch or stop and start your instance to refesh its identity document") - } - err = registerSvc(svc, data[i], ztsUrl, opts) - if err != nil { - log.Fatalf("unable to register identity for svc: %q, error: %v", svc.Name, err) - } - } - } - - util.NotifySystemdReadyForCommand(cmd, "systemd-notify") - log.Printf("Identity established for services: %s\n", svcs) - - stop := make(chan bool, 1) - errors := make(chan error, 1) - certUpdates := make(chan bool, 1) - - // keep track of failed counts for refresh operations. Since we typically - // get certs valid for several days, there is no need to exit immediately - // and keep retrying. instead, we'll just skip this run and retry again - // in the configured number of minutes (based on the refresh interval) - failedRefreshCount := 0 - - go func() { - for { - // if we just did our initial setup there is no point - // to refresh the certs again. so we are going to skip - // this time around and refresh certs next time - - if !initialSetup { - data, err := attestation.GetAttestationData(opts) - if err != nil { - errors <- fmt.Errorf("Cannot get attestation data: %v\n", err) - return - } - err = RefreshInstance(data, ztsUrl, opts) - if err != nil { - failedRefreshCount++ - if shouldExitRightAway(failedRefreshCount, opts) { - errors <- fmt.Errorf("refresh identity failed: %v\n", err) - return - } else { - util.ExecuteScriptWithoutBlock(opts.RunAfterCertsErrParts, svcs, false) - log.Printf("refresh identity failed for svcs %s, error: %v\n", svcs, err) - log.Printf("refresh will be retried in %d minutes, failure %d of %d\n", opts.RefreshInterval, failedRefreshCount, opts.FailCountForExit) - } - } else { - failedRefreshCount = 0 - log.Printf("identity successfully refreshed for services: %s\n", svcs) - } - } - initialSetup = false - if tokenOpts != nil { - err := accessTokenRequest(tokenOpts) - if err != nil { - util.ExecuteScriptWithoutBlock(opts.RunAfterTokensErrParts, err.Error(), false) - } else { - util.ExecuteScriptWithoutBlock(opts.RunAfterTokensOkParts, "", opts.RunAfterFailExit) - } - } else { - log.Print("token config does not exist - do not refresh token") - } - _, failures := GetRoleCertificates(ztsUrl, opts) - if len(failures) != 0 { - util.ExecuteScriptWithoutBlock(opts.RunAfterCertsErrParts, strings.Join(failures, ","), false) - } - util.ExecuteScriptWithoutBlock(opts.RunAfterCertsOkParts, "", opts.RunAfterFailExit) - util.NotifySystemdReadyForCommand(cmd, "systemd-notify-all") - - if opts.SDSUdsPath != "" { - certUpdates <- true - } - - select { - case <-stop: - errors <- nil - return - case <-time.After(rotationInterval): - break - } - } - }() - - go func() { - if opts.SDSUdsPath != "" { - err := sds.StartGrpcServer(opts, certUpdates) - if err != nil { - log.Printf("failed to start grpc/uds server: %v\n", err) - stop <- true - return - } - } - }() - - go func() { - signals := make(chan os.Signal, 2) - signal.Notify(signals, os.Interrupt, syscall.SIGTERM) - sig := <-signals - log.Printf("Received signal %v, stopping rotation\n", sig) - stop <- true - }() - - go func() { - if tokenOpts == nil || tokenOpts.TokenRefresh == 0 { - return - } - - log.Printf("start refresh access-token task every [%s]", fmt.Sprint(tokenOpts.TokenRefresh)) - t2 := time.NewTicker(tokenOpts.TokenRefresh) - defer t2.Stop() - for { - select { - case <-t2.C: - log.Printf("refreshing access-token..") - err := accessTokenRequest(tokenOpts) - if err != nil { - util.ExecuteScriptWithoutBlock(opts.RunAfterTokensErrParts, err.Error(), false) - } else { - util.ExecuteScriptWithoutBlock(opts.RunAfterTokensOkParts, "", opts.RunAfterFailExit) - } - case <-stop: - errors <- nil - return - } - } - }() - - err = <-errors - if err != nil { - log.Printf("%v\n", err) - } - } -} - -func accessTokenRequest(tokenOpts *config.TokenOptions) error { - // getExponentialBackoffToken will return a backoff config with first retry delay of 5s, and backoff retry - // until params.tokenRefresh / 4 - getExponentialBackoffToken := func() *backoff.ExponentialBackOff { - b := backoff.NewExponentialBackOff() - b.InitialInterval = 5 * time.Second - b.Multiplier = 2 - b.MaxElapsedTime = tokenOpts.TokenRefresh / 4 - return b - } - - notifyOnAccessTokenErr := func(err error, backoffDelay time.Duration) { - log.Printf("Failed to create/refresh access token: %s. Retrying in %s", err.Error(), backoffDelay) - } - - accessTokenFunc := func() error { - return fetchAccessToken(tokenOpts) - } - err := backoff.RetryNotify(accessTokenFunc, getExponentialBackoffToken(), notifyOnAccessTokenErr) - - if err != nil { - log.Printf("access tokens errors: %v", err) - } - return err -} - -func tokenOptions(opts *sc.Options, ztsUrl string) (*config.TokenOptions, error) { - userAgent := fmt.Sprintf("%s-%s", opts.Provider, opts.InstanceId) - tokenOpts, err := tokens.NewTokenOptions(opts, ztsUrl, userAgent) - if err != nil { - return nil, fmt.Errorf("processing access tokens: %s", err.Error()) - } - if opts.StoreTokenOption != nil { - tokenOpts.StoreOptions = config.StoreTokenOptions(*opts.StoreTokenOption) - } else { - tokenOpts.StoreOptions = config.AccessTokenProp - } - - log.Printf("token options created successfully") - return tokenOpts, nil -} - -func fetchAccessToken(tokenOpts *config.TokenOptions) error { - - _, errs := tokens.Fetch(tokenOpts) - log.Printf("Fetch access token completed successfully with [%d] errors", len(errs)) - - switch len(errs) { - case 0: - return nil - case 1: - return errs[0] - default: - var errsStr []string - for _, er := range errs { - errsStr = append(errsStr, er.Error()) - } - return fmt.Errorf(strings.Join(errsStr, ",")) - } -} - -func shouldSkipRegister(opts *sc.Options) bool { - if opts.EC2StartTime == nil { - return false - } - duration := time.Since(*opts.EC2StartTime) - //our server timeout is 30 mins = 1800 secs - return duration.Seconds() > 1800 -} - -func serviceAlreadyRegistered(opts *sc.Options, svc sc.Service) bool { - keyFile := util.GetSvcKeyFileName(opts.KeyDir, svc.KeyFilename, opts.Domain, svc.Name) - certFile := util.GetSvcCertFileName(opts.CertDir, svc.CertFilename, opts.Domain, svc.Name) - return util.FileExists(keyFile) && util.FileExists(certFile) -} - -func shouldExitRightAway(failedRefreshCount int, opts *sc.Options) bool { - // if the failed count already matches or exceeds our configured - // value then we return right away - if failedRefreshCount >= opts.FailCountForExit { - return true - } - // if the count hasn't reached the limit, we will skip this - // failure only if all the certificates that we're refreshing - // are not going to expire before the next refresh happens - for _, svc := range opts.Services { - svcCertFile := util.GetSvcCertFileName(opts.CertDir, svc.CertFilename, opts.Domain, svc.Name) - // if we're not able to parse/load the certificate file, we'll exit right away - x509Cert, err := athenzutils.LoadX509Certificate(svcCertFile) - if err != nil { - return true - } - if x509Cert.NotAfter.Unix()-time.Now().Unix() < int64(opts.RefreshInterval*60) { - return true - } - } - return false -} diff --git a/libs/go/sia/aws/agent/agent_test.go b/libs/go/sia/aws/agent/agent_test.go deleted file mode 100644 index 81519fbe982..00000000000 --- a/libs/go/sia/aws/agent/agent_test.go +++ /dev/null @@ -1,687 +0,0 @@ -// -// Copyright The Athenz Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package agent - -import ( - "crypto" - "crypto/x509" - "crypto/x509/pkix" - "encoding/json" - "fmt" - "log" - "net" - "net/url" - "os" - "testing" - "time" - - "github.com/AthenZ/athenz/libs/go/sia/access/config" - "github.com/AthenZ/athenz/libs/go/sia/aws/agent/devel/ztsmock" - "github.com/AthenZ/athenz/libs/go/sia/aws/attestation" - sc "github.com/AthenZ/athenz/libs/go/sia/config" - "github.com/AthenZ/athenz/libs/go/sia/host/ip" - "github.com/AthenZ/athenz/libs/go/sia/host/signature" - "github.com/AthenZ/athenz/libs/go/sia/ssh/hostkey" - "github.com/AthenZ/athenz/libs/go/sia/util" - "github.com/stretchr/testify/assert" -) - -func setup() { - go ztsmock.StartZtsServer("127.0.0.1:5084") - time.Sleep(3 * time.Second) -} - -func teardown() {} - -func TestMain(m *testing.M) { - setup() - code := m.Run() - teardown() - os.Exit(code) -} - -type TestProvider struct { - Name string - Hostname string -} - -// GetName returns the name of the current provider -func (tp TestProvider) GetName() string { - return tp.Name -} - -// GetHostname returns the hostname as per the provider -func (tp TestProvider) GetHostname(bool) string { - return tp.Hostname -} - -func (tp TestProvider) AttestationData(string, crypto.PrivateKey, *signature.SignatureInfo) (string, error) { - return "", fmt.Errorf("not implemented") -} - -func (tp TestProvider) PrepareKey(string) (crypto.PrivateKey, error) { - return "", fmt.Errorf("not implemented") -} - -func (tp TestProvider) GetCsrDn() pkix.Name { - return pkix.Name{} -} - -func (tp TestProvider) GetSanDns(string, bool, bool, []string) []string { - return nil -} - -func (tp TestProvider) GetSanUri(string, ip.Opts, string, string) []*url.URL { - return nil -} - -func (tp TestProvider) GetEmail(string) []string { - return nil -} - -func (tp TestProvider) GetRoleDnsNames(*x509.Certificate, string) []string { - return nil -} - -func (tp TestProvider) GetSanIp(map[string]bool, []net.IP, ip.Opts) []net.IP { - return nil -} - -func (tp TestProvider) GetSuffixes() []string { - return []string{} -} - -func (tp TestProvider) CloudAttestationData(string, string, string) (string, error) { - a, _ := json.Marshal(&attestation.AttestationData{ - Role: "athenz.hockey", - CommonName: "athenz.hockey", - }) - - return string(a), nil -} - -func (tp TestProvider) GetAccountDomainServiceFromMeta(string) (string, string, string, error) { - return "testAcct", "testDom", "testSvc", nil -} - -func (tp TestProvider) GetAccessManagementProfileFromMeta(string) (string, error) { - return "testProf", nil -} - -func (tp TestProvider) GetAdditionalSshHostPrincipals(string) (string, error) { - return "i-1234edt22", nil -} - -func TestUpdateFileNew(test *testing.T) { - testInternalUpdateFileNew(test, true) - testInternalUpdateFileNew(test, false) -} - -func testInternalUpdateFileNew(test *testing.T, fileDirectUpdate bool) { - - //make sure our temp file does not exist - timeNano := time.Now().UnixNano() - fileName := fmt.Sprintf("sia-test.tmp%d", timeNano) - _ = os.Remove(fileName) - testContents := "sia-unit-test" - err := util.UpdateFile(fileName, []byte(testContents), util.ExecIdCommand("-u"), util.ExecIdCommand("-g"), 0644, fileDirectUpdate, true) - if err != nil { - test.Errorf("Cannot create new file: %v", err) - return - } - data, err := os.ReadFile(fileName) - if err != nil { - test.Errorf("Cannot read new created file: %v", err) - _ = os.Remove(fileName) - return - } - if string(data) != testContents { - test.Errorf("Read %s data not the same as stored %s data", data, testContents) - _ = os.Remove(fileName) - return - } - _ = os.Remove(fileName) -} - -func TestUpdateFileExisting(test *testing.T) { - testInternalUpdateFileExisting(test, true) - testInternalUpdateFileExisting(test, false) -} - -func testInternalUpdateFileExisting(test *testing.T, fileDirectUpdate bool) { - - //create our temporary file - timeNano := time.Now().UnixNano() - fileName := fmt.Sprintf("sia-test.tmp%d", timeNano) - testContents := "sia-unit-test" - err := os.WriteFile(fileName, []byte(testContents), 0644) - if err != nil { - test.Errorf("Cannot create new file: %v", err) - return - } - testNewContents := "sia-unit" - err = util.UpdateFile(fileName, []byte(testNewContents), util.ExecIdCommand("-u"), util.ExecIdCommand("-g"), 0644, fileDirectUpdate, true) - if err != nil { - test.Errorf("Cannot create new file: %v", err) - return - } - data, err := os.ReadFile(fileName) - if err != nil { - test.Errorf("Cannot read new created file: %v", err) - _ = os.Remove(fileName) - return - } - if string(data) != testNewContents { - test.Errorf("Read %s data not the same as stored %s data", data, testNewContents) - _ = os.Remove(fileName) - return - } - _ = os.Remove(fileName) -} - -func TestRegisterInstance(test *testing.T) { - - siaDir := test.TempDir() - - keyFile := fmt.Sprintf("%s/athenz.hockey.key.pem", siaDir) - certFile := fmt.Sprintf("%s/athenz.hockey.cert.pem", siaDir) - caCertFile := fmt.Sprintf("%s/ca.cert.pem", siaDir) - - tp := TestProvider{ - Name: "athenz.aws.us-west-2", - } - opts := &sc.Options{ - Domain: "athenz", - Services: []sc.Service{ - { - Name: "hockey", - Uid: util.ExecIdCommand("-u"), - Gid: util.ExecIdCommand("-g"), - }, - }, - KeyDir: siaDir, - CertDir: siaDir, - AthenzCACertFile: caCertFile, - ZTSAWSDomains: []string{"zts-aws-cloud"}, - Region: "us-west-2", - InstanceId: "pod-1234", - Provider: tp, - SanDnsHostname: true, - } - - a := &attestation.AttestationData{ - Role: "athenz.hockey", - CommonName: "athenz.hockey", - } - - err := RegisterInstance([]*attestation.AttestationData{a}, "http://127.0.0.1:5084/zts/v1", opts, false) - assert.Nil(test, err, "unable to register instance") - - if err != nil { - test.Errorf("Unable to register instance: %v", err) - return - } - _, err = os.Stat(keyFile) - if err != nil { - test.Errorf("Unable to validate private key file: %v", err) - } - _, err = os.Stat(certFile) - if err != nil { - test.Errorf("Unable to validate x509 certificate file: %v", err) - } - _, err = os.Stat(caCertFile) - if err != nil { - test.Errorf("Unable to validate CA certificate file: %v", err) - } -} - -func copyFile(src, dst string) error { - data, err := os.ReadFile(src) - if err != nil { - return err - } - return os.WriteFile(dst, data, 0644) -} - -func refreshServiceCertSetup(test *testing.T) (*sc.Options, *attestation.AttestationData, string) { - - siaDir := test.TempDir() - - keyFile := fmt.Sprintf("%s/athenz.hockey.key.pem", siaDir) - certFile := fmt.Sprintf("%s/athenz.hockey.cert.pem", siaDir) - caCertFile := fmt.Sprintf("%s/ca.cert.pem", siaDir) - - err := copyFile("devel/data/key.pem", keyFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/key.pem", keyFile, err) - return nil, nil, "" - } - err = copyFile("devel/data/cert.pem", certFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/cert.pem", certFile, err) - return nil, nil, "" - } - err = copyFile("devel/data/ca.cert.pem", caCertFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/ca.cert..pem", caCertFile, err) - return nil, nil, "" - } - - tp := TestProvider{ - Name: "athenz.aws.us-west-2", - } - opts := &sc.Options{ - Domain: "athenz", - Services: []sc.Service{ - { - Name: "hockey", - Uid: util.ExecIdCommand("-u"), - Gid: util.ExecIdCommand("-g"), - FileMode: 0400, - }, - }, - KeyDir: siaDir, - CertDir: siaDir, - AthenzCACertFile: caCertFile, - Provider: tp, - ZTSAWSDomains: []string{"zts-aws-cloud"}, - Region: "us-west-2", - InstanceId: "pod-1234", - } - - attestationData := &attestation.AttestationData{ - Role: "athenz.hockey", - CommonName: "athenz.hockey", - } - - return opts, attestationData, certFile -} - -func TestRefreshInstance(test *testing.T) { - - opts, attestationData, certFile := refreshServiceCertSetup(test) - if opts == nil || attestationData == nil { - test.Errorf("Certificate setup was not completed successfully") - return - } - - err := RefreshInstance([]*attestation.AttestationData{attestationData}, "http://127.0.0.1:5084/zts/v1", opts) - assert.Nil(test, err, fmt.Sprintf("unable to refresh instance: %v", err)) - - oldCert, _ := os.ReadFile("devel/data/cert.pem") - newCert, _ := os.ReadFile(certFile) - if string(oldCert) == string(newCert) { - test.Errorf("Certificate was not refreshed") - return - } -} - -func TestRoleCertificateRequest(test *testing.T) { - - siaDir := test.TempDir() - - keyFile := fmt.Sprintf("%s/athenz.hockey.key.pem", siaDir) - certFile := fmt.Sprintf("%s/athenz.hockey.cert.pem", siaDir) - caCertFile := fmt.Sprintf("%s/ca.cert.pem", siaDir) - roleCertFile := fmt.Sprintf("%s/testrole.cert.pem", siaDir) - - err := copyFile("devel/data/key.pem", keyFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/key.pem", keyFile, err) - return - } - err = copyFile("devel/data/cert.pem", certFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/cert.pem", certFile, err) - return - } - err = copyFile("devel/data/ca.cert.pem", caCertFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/ca.cert..pem", caCertFile, err) - return - } - - tp := TestProvider{ - Name: "athenz.aws.us-west-2", - } - opts := &sc.Options{ - Domain: "athenz", - Services: []sc.Service{ - { - Name: "hockey", - Uid: util.ExecIdCommand("-u"), - Gid: util.ExecIdCommand("-g"), - FileMode: 0400, - }, - }, - Roles: []sc.Role{ - { - Name: "athenz:role.writers", - Service: "hockey", - Uid: util.ExecIdCommand("-u"), - Gid: util.ExecIdCommand("-g"), - RoleCertFilename: roleCertFile, - FileMode: 0400, - }, - }, - KeyDir: siaDir, - CertDir: siaDir, - AthenzCACertFile: caCertFile, - ZTSAWSDomains: []string{"zts-aws-cloud"}, - Provider: tp, - } - - _, failures := GetRoleCertificates("http://127.0.0.1:5084/zts/v1", opts) - if len(failures) != 0 { - test.Errorf("Unable to get role certificate: %v", err) - return - } - - _, err = os.Stat(roleCertFile) - if err != nil { - test.Errorf("Unable to validate role certificate file: %v", err) - } -} - -func TestShouldSkipRegister(test *testing.T) { - startTime := time.Now() - opts := &sc.Options{ - EC2StartTime: &startTime, - } - //current time is valid - if shouldSkipRegister(opts) { - test.Errorf("Current time is considered expired incorrectly") - } - //generate time stamp 29 mins ago - valid - startTime = time.Now().Add(time.Minute * 29 * -1) - opts.EC2StartTime = &startTime - if shouldSkipRegister(opts) { - test.Errorf("29 mins ago time is considered expired incorrectly") - } - //generate time stamp 31 mins ago - expired - startTime = time.Now().Add(time.Minute * 31 * -1) - opts.EC2StartTime = &startTime - if !shouldSkipRegister(opts) { - test.Errorf("31 mins ago time is considered not expired incorrectly") - } -} - -func TestHostCertificateLinePresent(test *testing.T) { - tests := []struct { - name string - data string - certFile string - result bool - }{ - {"valid-start", "HostCertificate /sshd.config", "/sshd.config", true}, - {"valid-mid", "PermitTunnel no\nHostCertificate /sshd.config\nUseDNS no", "/sshd.config", true}, - {"valid-mid-space", "PermitTunnel no\n HostCertificate /sshd.config\nUseDNS no", "/sshd.config", true}, - {"valid-mid-tab", "PermitTunnel no\n\tHostCertificate /sshd.config\nUseDNS no", "/sshd.config", true}, - {"valid-mid-mix", "PermitTunnel no\n \t HostCertificate /sshd.config\nUseDNS no", "/sshd.config", true}, - {"valid-end", "PermitTunnel no\nHostCertificate /sshd.config", "/sshd.config", true}, - {"valid-commented", "PermitTunnel no\n#HostCertificate /sshd.config\nUseDNS no", "/sshd.config", false}, - {"valid-not-present1", "PermitTunnel no\nHostCertificateOther /sshd.config\nUseDNS no", "/sshd.config", false}, - {"valid-not-present2", "PermitTunnel no\nHostCertificate/sshd.config\nUseDNS no", "/sshd.config", false}, - {"valid-not-present3", "PermitTunnel no\n\nUseDNS no\n", "/sshd.config", false}, - {"valid-not-present3", "PermitTunnel no\nHostCertificate /sshd2.config\nUseDNS no\n", "/sshd.config", false}, - } - for _, tt := range tests { - test.Run(tt.name, func(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "sia-agent-test-") - if err != nil { - log.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile.Name()) - os.WriteFile(tmpFile.Name(), []byte(tt.data), 644) - result, _ := hostCertificateLinePresent(tmpFile.Name(), tt.certFile) - if result != tt.result { - test.Errorf("%s: invalid value returned - expected: %v, received %v", tt.name, tt.result, result) - } - }) - } -} - -func TestUpdateSSHConfigFile(test *testing.T) { - tests := []struct { - name string - data string - result string - }{ - {"test1", "PermitTunnel no\nUseDNS no", "PermitTunnel no\nUseDNS no\nHostCertificate /sshd.config\n"}, - {"test2", "PermitTunnel no\n#HostCertificate /sshd.config\nUseDNS no\n", "PermitTunnel no\n#HostCertificate /sshd.config\nUseDNS no\n\nHostCertificate /sshd.config\n"}, - } - for _, tt := range tests { - test.Run(tt.name, func(t *testing.T) { - tmpFile, err := os.CreateTemp(os.TempDir(), "sia-agent-test-") - if err != nil { - log.Fatal("Cannot create temporary file", err) - } - defer os.Remove(tmpFile.Name()) - os.WriteFile(tmpFile.Name(), []byte(tt.data), 644) - err = updateSSHConfigFile(tmpFile.Name(), "/sshd.config") - if err != nil { - test.Errorf("%s: unable to update file %s - error: %v", tt.name, tmpFile.Name(), err) - } - data, _ := os.ReadFile(tmpFile.Name()) - if tt.result != string(data) { - test.Errorf("%s: invalid value returned - expected: %v, received %v", tt.name, tt.result, string(data)) - } - }) - } -} - -func TestNilTokenOptions(test *testing.T) { - opts := &sc.Options{ - Domain: "athenz", - } - token, err := tokenOptions(opts, "") - assert.Nil(test, token, "should not create token") - assert.NotNil(test, err, "token is not presented") -} - -func TestTokenStoreOptions(test *testing.T) { - opts := &sc.Options{ - Domain: "athenz", - AccessTokens: []config.AccessToken{ - { - FileName: "reader", - Domain: "athenz", - Service: "api", - }, - }, - TokenDir: "/tmp", - CertDir: "/tmp", - KeyDir: "/tmp", - BackupDir: "/tmp", - } - token, err := tokenOptions(opts, "") - assert.Nil(test, err) - assert.Equal(test, token.StoreOptions, config.AccessTokenProp) - - // set the token option value - tokenOption := 2 - opts.StoreTokenOption = &tokenOption - - token, err = tokenOptions(opts, "") - assert.Nil(test, err) - assert.Equal(test, token.StoreOptions, config.AccessTokenWithoutQuotesProp) -} - -func TestGetServiceHostname(test *testing.T) { - tests := []struct { - name string - sanDnsHostname bool - providerHostname string - hostnameSuffix string - service string - domain string - result string - }{ - {"disabled", false, "unknown", "unknown", "api", "sports", ""}, - {"no hostname", true, "", "unknown", "api", "sports", ""}, - {"valid hostname", true, "zts.athenz.cloud", "unknown", "api", "sports", "zts.athenz.cloud"}, - {"no suffix", true, "zts", "", "api", "athenz", ""}, - {"autogenerated - top domain", true, "zts", "athenz.cloud", "api", "sports", "zts.api.sports.athenz.cloud"}, - {"autogenerated - subdomain", true, "zts", "athenz.cloud", "api", "sports.prod", "zts.api.sports-prod.athenz.cloud"}, - } - for _, tt := range tests { - test.Run(tt.name, func(t *testing.T) { - provider := TestProvider{ - Name: "testProvider", - Hostname: tt.providerHostname, - } - opts := sc.Options{ - SanDnsHostname: tt.sanDnsHostname, - HostnameSuffix: tt.hostnameSuffix, - Domain: tt.domain, - Provider: provider, - } - svc := sc.Service{ - Name: tt.service, - } - hostname := getServiceHostname(&opts, svc, false) - if tt.result != hostname { - test.Errorf("%s: invalid value returned - expected: %v, received %v", tt.name, tt.result, hostname) - } - }) - } -} - -func TestServiceAlreadyRegistered(test *testing.T) { - - keyDir := test.TempDir() - certDir := test.TempDir() - opts := sc.Options{ - KeyDir: keyDir, - CertDir: certDir, - Domain: "athenz", - } - keyFile := fmt.Sprintf("%s/athenz.hockey.key.pem", keyDir) - certFile := fmt.Sprintf("%s/athenz.hockey.cert.pem", certDir) - err := copyFile("devel/data/key.pem", keyFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/key.pem", keyFile, err) - return - } - err = copyFile("devel/data/cert.pem", certFile) - if err != nil { - test.Errorf("Unable to copy file %s to %s - %v\n", "devel/data/cert.pem", certFile, err) - return - } - tests := []struct { - name string - keyFileName string - certFileName string - result bool - }{ - {"both-valid", keyFile, certFile, true}, - {"key-valid-only", keyFile, "", false}, - {"cert-valid only", "", certFile, false}, - {"both-invalid", "", "", false}, - } - for _, tt := range tests { - test.Run(tt.name, func(t *testing.T) { - svc := sc.Service{ - Name: "api", - KeyFilename: tt.keyFileName, - CertFilename: tt.certFileName, - } - serviceRegistered := serviceAlreadyRegistered(&opts, svc) - if tt.result != serviceRegistered { - test.Errorf("%s: invalid value returned - expected: %v, received %v", tt.name, tt.result, serviceRegistered) - } - }) - } -} - -func TestGenerateSshRequest(test *testing.T) { - - tp := TestProvider{ - Name: "athenz.aws.us-west-2", - } - opts := sc.Options{ - Ssh: false, - Provider: tp, - } - // ssh option false we should get success with nils and empty csr - sshReq, sshCsr, err := generateSshRequest(&opts, "backend", "hostname.athenz.io") - assert.Nil(test, sshReq) - assert.Equal(test, "", sshCsr) - assert.Nil(test, err) - // ssh enabled but not for primary service we should get success with nils and empty csr - opts.Ssh = true - opts.Services = []sc.Service{ - { - Name: "api", - }, - } - sshReq, sshCsr, err = generateSshRequest(&opts, "backend", "hostname.athenz.io") - assert.Nil(test, sshReq) - assert.Equal(test, "", sshCsr) - assert.Nil(test, err) - // ssh enabled with primary service and key type is rsa - null cert request but valid csr - opts.SshPubKeyFile = "devel/data/cert.pem" - opts.Domain = "athenz" - opts.ZTSAWSDomains = []string{"athenz.io"} - opts.SshHostKeyType = hostkey.Rsa - sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io") - assert.Nil(test, sshReq) - assert.NotEmpty(test, sshCsr) - assert.Nil(test, err) - // ssh enabled with primary service and key type is ecdsa - empty csr but not-nil cert request - opts.SshHostKeyType = hostkey.Ecdsa - sshReq, sshCsr, err = generateSshRequest(&opts, "api", "hostname.athenz.io") - assert.NotNil(test, sshReq) - assert.Empty(test, sshCsr) - assert.Nil(test, err) -} - -func TestShouldExitRightAwayCountsOnly(test *testing.T) { - - opts := &sc.Options{ - FailCountForExit: 2, - } - - assert.True(test, shouldExitRightAway(2, opts)) - assert.True(test, shouldExitRightAway(3, opts)) - assert.False(test, shouldExitRightAway(0, opts)) - assert.False(test, shouldExitRightAway(1, opts)) -} - -func TestShouldExitRightAwayCertificate(test *testing.T) { - - opts, attestationData, _ := refreshServiceCertSetup(test) - if opts == nil || attestationData == nil { - test.Errorf("Certificate setup was not completed successfully") - return - } - - err := RefreshInstance([]*attestation.AttestationData{attestationData}, "http://127.0.0.1:5084/zts/v1", opts) - assert.Nil(test, err, fmt.Sprintf("unable to refresh instance: %v", err)) - - // our certs are valid for 30 days, so we'll set the refresh - // interval to 31 days, it should fail, but if we set it - // to 28 days, it should be ok - - opts.FailCountForExit = 2 - opts.RefreshInterval = 28 * 24 * 60 - assert.False(test, shouldExitRightAway(1, opts)) - - opts.RefreshInterval = 31 * 24 * 60 - assert.True(test, shouldExitRightAway(1, opts)) - -} diff --git a/libs/go/sia/aws/agent/devel/data/ca.cert.pem b/libs/go/sia/aws/agent/devel/data/ca.cert.pem deleted file mode 100644 index 154b92aca9d..00000000000 --- a/libs/go/sia/aws/agent/devel/data/ca.cert.pem +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDnDCCAoSgAwIBAgIRALKqcOsUOQhttiW8ni3kescwDQYJKoZIhvcNAQEFBQAw -dzELMAkGA1UEBhMCVVMxDzANBgNVBAgTBk9yZWdvbjERMA8GA1UEBxMIU3RhZmZv -cmQxDTALBgNVBAoTBFRyb3kxIzAhBgNVBAsTGlRyb3kgQ2VydGlmaWNhdGUgQXV0 -aG9yaXR5MRAwDgYDVQQDEwdUcm95IENBMB4XDTE3MDkxNjAwMTE0MFoXDTE4MDkx -NjAwMTE0MFowdzELMAkGA1UEBhMCVVMxDzANBgNVBAgTBk9yZWdvbjERMA8GA1UE -BxMIU3RhZmZvcmQxDTALBgNVBAoTBFRyb3kxIzAhBgNVBAsTGlRyb3kgQ2VydGlm -aWNhdGUgQXV0aG9yaXR5MRAwDgYDVQQDEwdUcm95IENBMIIBIjANBgkqhkiG9w0B -AQEFAAOCAQ8AMIIBCgKCAQEApEpwofGMq7vKyApa7f4/wzLMjl/O7EXDPPNygddh -Xc16OJ4qsFChPKkFEPa9+XjNQ6Ie+t7AVYZnn9Vlfxsyjilh5yvNkWC+gAHEFxZV -+4iAXuz5o1wYK67IS2RDLWDRGDoqLbYxclnpznkNkzOArP9Em8vKiYiYgmj5q1KD -9q1h3yrn1XZ8JSZV52tU8xFsTId+Pfd2BXyDcxU+yMeRfVGRroXpeFF2Ovji/rG8 -ZJfsQqUijxiT4f2CR68LTBQ6sP2GiEI2kpwh2zw/rV9JFB2dMOCxCS+KX052EcG9 -9YcZRvh394MRbUtYaW1oB/UB7YD8zzhljFvJdVY1LsDqKwIDAQABoyMwITAOBgNV -HQ8BAf8EBAMCAaYwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEA -X/Be9cMla+oT1gbK+b7mwKn1NwCFlsDOIOtoeBX2F7f8hSutKpiZxSNqrLS9haSf -t2m43PkZyRxclGDg++AsLCSS9oKxJ1eg1Nh+3lN2fiwYZZGvXjspycBfNbJsEYer -0b6v2lY9qpecQ2e6Cvy5i58WYQjshsOim9UfuQtu1/2d5I+UfOvpR+7dxq5KwNbp -SBjiR9b1jfj5B83bdsBbblZDtnwCnR/q/OOrGaIzGgwLrmqUY3z8RbvL/Ng/slki -Aqpq5CDLPk2J7NjfCC3OXVWu8M/Qw1vlqhyx0UdchqArFnsjzWJBFNu/6iCkQihd -/UMhfcP7cYAU9RfoI06tGg== ------END CERTIFICATE----- diff --git a/libs/go/sia/aws/agent/devel/data/cert.pem b/libs/go/sia/aws/agent/devel/data/cert.pem deleted file mode 100644 index 9f27cb9fbe8..00000000000 --- a/libs/go/sia/aws/agent/devel/data/cert.pem +++ /dev/null @@ -1,24 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIECTCCAvGgAwIBAgIRAMPHWJCjbWwifaM+FOqzEScwDQYJKoZIhvcNAQELBQAw -dzELMAkGA1UEBhMCVVMxDzANBgNVBAgTBk9yZWdvbjERMA8GA1UEBxMIU3RhZmZv -cmQxDTALBgNVBAoTBFRyb3kxIzAhBgNVBAsTGlRyb3kgQ2VydGlmaWNhdGUgQXV0 -aG9yaXR5MRAwDgYDVQQDEwdUcm95IENBMB4XDTE3MDkxNTIzNTY0MFoXDTE3MTAx -NjAwMTE0MFowRTELMAkGA1UEBhMCVVMxDTALBgNVBAoTBE9hdGgxDzANBgNVBAsT -BkF0aGVuejEWMBQGA1UEAxMNYXRoZW56LmhvY2tleTCCASIwDQYJKoZIhvcNAQEB -BQADggEPADCCAQoCggEBAOlLHnTor/hKxGrEdQvwZiMboU2mggoZl+gM1sld+rrM -lrUpJLCb5Tc34Wib3RJee7sxF7LQCkJe5ljsI0eXhzfGbn1hw34kk+VeNN8ns1DB -kVMExDBzCEar/a85QhHenvxljIdyKaELzNSN1/kONeyyTe8zPeQNeMcnvKqw0uwc -XKv/jtyzJj1T8/Xi2K54BnUL7qghwIxNJsevb1noF38/06JdNOQbaswe+hnvMnxr -i/WYUH85j62edHu2nEB8XoIRHWw5pag/xrALeaxmXfiJPabepQVpDJmTyRiWWCmF -0n4c4ZAzW+wEeUIuTpSKMI3NhFqNny3aLWi1UMYG5ucCAwEAAaOBwTCBvjAOBgNV -HQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMAwGA1Ud -EwEB/wQCMAAwfwYDVR0RBHgwdoItaG9ja2V5LmF0aGVuei51cy13ZXN0LTIuYXdz -LmF0aGVuei5vYXRoLmNsb3VkgkVpLTAzZDFhZTcwMzVmOTMxYTkwLmluc3RhbmNl -aWQuYXRoZW56LnVzLXdlc3QtMi5hd3MuYXRoZW56Lm9hdGguY2xvdWQwDQYJKoZI -hvcNAQELBQADggEBAJfln2+V4Rj/QDstxZXxCwnc0bbbI0NeVEUstLeM9+nRiLz6 -lzGcNsvlKMTBD6haDz9qUvPvZmYa8mqkPtWsDaB0p1ztbskdgQODFlnAl9O1LqlW -u48aU+5l/MCqFio0pdCuiYBEW0CACO+wXTOsWeE0jOPQFqwi3PT60w5y/qsIjq+0 -kNbVXjeP/8LEdlGAz8h0k80zNEM6MQ6wHw6t2cI0SQwr6xJyozPhUQdCzcwJ99OH -+WYipC9fPBrz/WSEN3ZkBfyXtN/uChYICPpMnTQD85ZbmGM+n2oaqbYM60BdGWOt -0WgY3NI56U2bMOfjoXMAIANYuPZRReO5wzjlkz0= ------END CERTIFICATE----- diff --git a/libs/go/sia/aws/agent/devel/data/key.pem b/libs/go/sia/aws/agent/devel/data/key.pem deleted file mode 100644 index b7fa17387d4..00000000000 --- a/libs/go/sia/aws/agent/devel/data/key.pem +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEA6UsedOiv+ErEasR1C/BmIxuhTaaCChmX6AzWyV36usyWtSkk -sJvlNzfhaJvdEl57uzEXstAKQl7mWOwjR5eHN8ZufWHDfiST5V403yezUMGRUwTE -MHMIRqv9rzlCEd6e/GWMh3IpoQvM1I3X+Q417LJN7zM95A14xye8qrDS7Bxcq/+O -3LMmPVPz9eLYrngGdQvuqCHAjE0mx69vWegXfz/Tol005BtqzB76Ge8yfGuL9ZhQ -fzmPrZ50e7acQHxeghEdbDmlqD/GsAt5rGZd+Ik9pt6lBWkMmZPJGJZYKYXSfhzh -kDNb7AR5Qi5OlIowjc2EWo2fLdotaLVQxgbm5wIDAQABAoIBAD/PwkbsFqXtnYgu -sG1Rlj5oIljhAJTOp1Rbnqx5vkk2CMsIs/ZyzeGqsUcxyuhpW6K6LOdGLGg3GP6d -qJC+i8ffyP0WrqhkTOfiOsgHTe7640s39InkDRF3ne491SqaIBadmDC8M1LPrXk+ -SyLeljVmGBcjhvxICw8+eUafEzJtkd9laxg1vBNunCeWXbIPEN+tUWLkuQ+mTzfQ -MobGBIKf8Etsw6qIFrwnLCvutRQo2USF++uwF9IwtGktg+M5DE+IJX8yOugQd2Kz -t95p1K1BUk15g7KI2XS8M5EeXvp8XmldpV8Rus3fnRSzfJK1tu2bdfjPOHtanhet -HdAg6pECgYEA97eLgFWIz/pAH651PQic7fyAEkSSAas+LBDqzBrJ4MiejJge1HRo -Y68hUGXB1AcljMBBRjSxVVpZCv63FHMrwqzhiOd2ndy528sY33QHnATnbOdBbrj4 -Ew1ZgKFHZns34m/3XTE71HS3j8BL+ES+hyAJTd4w671NWcs3HUFbew0CgYEA8Rgb -0ne66b0kHYmR651yoavl2BoDXV0VR3U6G5em68lU49M/N82HT18OSAGnK7+XqdiN -nZbguDGA5dkrGcVrHuxmbsw6PweKYY1ceR14GuC5YqGR0+b7VB+VTjApYkJvg37S -gXnbQTWheF6LhPqeteuEz76xvV74BvcaU+O93MMCgYEAmkLm0KhzZnDE9fXCdJuk -fl+7saSZ+AgX04FFdo1IIn9MnOkuaceEKm+pI1P6/Hrm21vuSjYOKMT2pm4wvL9s -BPN8D7F0oKIP69vyRVUQWAyFwb/Rc44kjljF3+CPgjZBevWW6aX7SDbXCOILbTQC -IkvE/4TamjNss/pk/AbzXRECgYB2QN38qK9sUFJzjnOdPrfoJplxGqlF1Q9H4m7i -88py4miZ6paad1wECVrG5NCrO6lXLJmhj0yf6+AOXSuv005Md6VyeQekvL0aRizy -Vwr/G/SyNQ+DAUuLIoaoOCVERdPOipkT2sI5ROXzVWRXkFniXyfggedKPFepivBF -73HD4wKBgFgxOmKSnzPU43t92ES4WzVXqAnQOq7JabBJn/bJSizRPZnqjgKw4Nlg -cowfx1bF0j04wcNBRaklhn4KDDHeknY1rNzFOPsRFPBO5QHk1iUmKJ4yEja77ti1 -K+6UNC57qJjRqWd/wC1fI+SVPXhaa+iLkG3eFKmwLs+qhKYGWtNA ------END RSA PRIVATE KEY----- diff --git a/libs/go/sia/aws/agent/devel/ztsmock/zts.go b/libs/go/sia/aws/agent/devel/ztsmock/zts.go deleted file mode 100644 index 62b622e39a7..00000000000 --- a/libs/go/sia/aws/agent/devel/ztsmock/zts.go +++ /dev/null @@ -1,356 +0,0 @@ -// -// Copyright The Athenz Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package ztsmock - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/json" - "encoding/pem" - "fmt" - "io" - "log" - "math/big" - "net" - "net/http" - "time" - - "github.com/AthenZ/athenz/clients/go/zts" - "github.com/gorilla/mux" -) - -var caKeyStr string -var caCertStr string - -func SetupCA() (string, string) { - - key, err := generateKeyPair() - if err != nil { - log.Fatalf("Cannot generate private key: %v\n", err) - } - - //create self-signed cert - country := "US" - province := "Oregon" - locality := "Stafford" - org := "Troy" - unit := "Troy Certificate Authority" - name := "Troy CA" - certPem, err := createCACert(key, country, locality, province, org, unit, name, nil, nil) - if err != nil { - log.Fatalf("Cannot create CA Cert: %v\n", err) - } - - return privatePem(key), certPem -} - -func StartZtsServer(endPoint string) { - router := mux.NewRouter() - - router.HandleFunc("/zts/v1/instance", func(w http.ResponseWriter, r *http.Request) { - log.Println("/instance is called") - - body, err := io.ReadAll(r.Body) - if err != nil { - log.Fatalln("Could not read the body") - } - - var data *zts.InstanceRegisterInformation - err = json.Unmarshal(body, &data) - if err != nil { - log.Fatalln("Could not parse the body into zts.InstanceRegisterInformation") - } - - caKey, err := privateKeyFromPem(caKeyStr) - if err != nil { - log.Fatalln("Could not generate caKey from string") - } - - caCert, err := certFromPEM(caCertStr) - if err != nil { - log.Fatalln("Could not generate caCert from string") - } - - service := fmt.Sprintf("%s.%s", data.Domain, data.Service) - cert, err := generateCertInMemory(data.Csr, caKey, caCert, service) - if err != nil { - log.Fatalf("Could not generate cert in memory: %v\n", err) - } - - identity := &zts.InstanceIdentity{ - Provider: data.Provider, - Name: zts.ServiceName(service), - InstanceId: "pod-1234", - X509CertificateSigner: caCertStr, - X509Certificate: cert, - } - identityBytes, err := json.Marshal(identity) - if err == nil { - w.WriteHeader(201) - io.WriteString(w, string(identityBytes)) - log.Println("Successfully processed register instance request") - } - }).Methods("POST") - - router.HandleFunc("/zts/v1/instance/athenz.aws.us-west-2/athenz/hockey/pod-1234", func(w http.ResponseWriter, r *http.Request) { - log.Println("instance refresh handler called") - - body, err := io.ReadAll(r.Body) - if err != nil { - log.Fatalln("Could not read the body") - } - var data *zts.InstanceRefreshInformation - err = json.Unmarshal(body, &data) - if err != nil { - log.Fatalln("Could not parse the body into zts.InstanceRefreshInformation") - } - - caKey, err := privateKeyFromPem(caKeyStr) - if err != nil { - log.Fatalln("Could not generate caKey from string") - } - - caCert, err := certFromPEM(caCertStr) - if err != nil { - log.Fatalln("Could not generate caCert from string") - } - - cert, err := generateCertInMemory(data.Csr, caKey, caCert, "athenz.hockey") - if err != nil { - log.Fatalf("Could not generate cert in memory: %v\n", err) - } - - identity := &zts.InstanceIdentity{ - Provider: "athenz.aws.us-west-2", - Name: zts.ServiceName("athenz.hockey"), - InstanceId: "pod-1234", - X509CertificateSigner: caCertStr, - X509Certificate: cert, - } - identityBytes, err := json.Marshal(identity) - if err == nil { - io.WriteString(w, string(identityBytes)) - log.Println("Successfully processed refresh instance request") - } - }).Methods("POST") - - router.HandleFunc("/zts/v1/rolecert", func(w http.ResponseWriter, r *http.Request) { - log.Println("role certificate handler called") - - body, err := io.ReadAll(r.Body) - if err != nil { - log.Fatalln("Could not read the body") - } - var data *zts.RoleCertificateRequest - err = json.Unmarshal(body, &data) - if err != nil { - log.Fatalln("Could not parse the body into zts.RoleCertificateRequest") - } - - caKey, err := privateKeyFromPem(caKeyStr) - if err != nil { - log.Fatalln("Could not generate caKey from string") - } - - caCert, err := certFromPEM(caCertStr) - if err != nil { - log.Fatalln("Could not generate caCert from string") - } - - cert, err := generateCertInMemory(data.Csr, caKey, caCert, "athenz:role.writers") - if err != nil { - log.Fatalf("Could not generate cert in memory: %v\n", err) - } - - identity := &zts.RoleCertificate{ - X509Certificate: cert, - } - identityBytes, err := json.Marshal(identity) - if err == nil { - io.WriteString(w, string(identityBytes)) - log.Println("Successfully processed role certificate request") - } - }).Methods("POST") - - err := http.ListenAndServe(endPoint, router) - if err != nil { - log.Fatal("ListenAndServe: ", err) - } - -} - -func init() { - caKeyStr, caCertStr = SetupCA() -} - -func generateKeyPair() (*rsa.PrivateKey, error) { - return rsa.GenerateKey(rand.Reader, 2048) -} - -func createCACert(key *rsa.PrivateKey, country, locality, province, org, unit, cn string, hosts []string, ips []net.IP) (string, error) { - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return "", err - } - algo := x509.SHA1WithRSA //for rsa - notBefore := time.Now() - validFor := 365 * 24 * time.Hour - notAfter := notBefore.Add(validFor) - subj := pkix.Name{ - CommonName: cn, - Country: []string{country}, - Locality: []string{locality}, - Province: []string{province}, - Organization: []string{org}, - OrganizationalUnit: []string{unit}, - } - - template := &x509.Certificate{ - Subject: subj, - SerialNumber: serialNumber, - PublicKeyAlgorithm: x509.RSA, - SignatureAlgorithm: algo, - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - if hosts != nil { - template.DNSNames = hosts - } - if ips != nil { - template.IPAddresses = ips - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) - if err != nil { - return "", err - } - certOut := bytes.NewBuffer(make([]byte, 0)) - err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) - if err != nil { - return "", fmt.Errorf("Cannot encode Cert to PEM: %v", err) - } - return certOut.String(), nil -} - -func privatePemBytes(privateKey *rsa.PrivateKey) []byte { - privatePem := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)} - privateBytes := pem.EncodeToMemory(privatePem) - return privateBytes -} - -func privatePem(privateKey *rsa.PrivateKey) string { - return string(privatePemBytes(privateKey)) -} - -func privateKeyFromPemBytes(pemBytes []byte) (*rsa.PrivateKey, error) { - block, _ := pem.Decode(pemBytes) - if block == nil { - return nil, fmt.Errorf("no PEM block found") - } - return x509.ParsePKCS1PrivateKey(block.Bytes) -} - -func privateKeyFromPem(pem string) (*rsa.PrivateKey, error) { - return privateKeyFromPemBytes([]byte(pem)) -} - -func certFromPEM(pemString string) (*x509.Certificate, error) { - return certFromPEMBytes([]byte(pemString)) -} - -func certFromPEMBytes(pemBytes []byte) (*x509.Certificate, error) { - var derBytes []byte - block, _ := pem.Decode(pemBytes) - if block == nil { - return nil, fmt.Errorf("Cannot parse cert (empty pem)") - } - derBytes = block.Bytes - cert, err := x509.ParseCertificate(derBytes) - if err != nil { - return nil, err - } - return cert, nil -} - -func generateCertInMemory(csrPem string, caKey *rsa.PrivateKey, caCert *x509.Certificate, cn string) (string, error) { - csr, err := decodeCSR(csrPem) - if err != nil { - return "", err - } - if cn != "" && cn != csr.Subject.CommonName { - return "", fmt.Errorf("CSR common name (%s) doesn't match expected common name (%s)", csr.Subject.CommonName, cn) - } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - algo := x509.SHA256WithRSA - now := time.Now() - tolerance := 15 * time.Minute // to account for time imprecision across machines - notBefore := now.Add(-tolerance) - validFor := 30 * 24 * time.Hour //30 day lifetime while debugging - notAfter := notBefore.Add(validFor + tolerance) - template := &x509.Certificate{ - Subject: csr.Subject, - SerialNumber: serialNumber, - PublicKeyAlgorithm: csr.PublicKeyAlgorithm, - SignatureAlgorithm: algo, - NotBefore: notBefore, - NotAfter: notAfter, - DNSNames: csr.DNSNames, - IPAddresses: csr.IPAddresses, - EmailAddresses: csr.EmailAddresses, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - cert, err := x509.CreateCertificate(rand.Reader, template, caCert, csr.PublicKey, caKey) - if err != nil { - return "", err - } - - certOut := bytes.NewBuffer(make([]byte, 0)) - err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: cert}) - if err != nil { - return "", fmt.Errorf("Cannot encode Cert to PEM: %v", err) - } - return certOut.String(), nil -} - -func decodeCSR(csr string) (*x509.CertificateRequest, error) { - var derBytes []byte - block, _ := pem.Decode([]byte(csr)) - if block == nil { - return nil, fmt.Errorf("Cannot parse CSR (empty pem)") - } - derBytes = block.Bytes - req, err := x509.ParseCertificateRequest(derBytes) - if err != nil { - return nil, err - } - err = req.CheckSignature() - if err != nil { - return nil, err - } - return req, nil -} diff --git a/libs/go/sia/aws/attestation/attestation.go b/libs/go/sia/aws/attestation/attestation.go index 90e64450a42..dbe29106a3e 100644 --- a/libs/go/sia/aws/attestation/attestation.go +++ b/libs/go/sia/aws/attestation/attestation.go @@ -25,7 +25,6 @@ import ( "strings" "github.com/AthenZ/athenz/libs/go/sia/aws/stssession" - sc "github.com/AthenZ/athenz/libs/go/sia/config" "github.com/aws/aws-sdk-go-v2/service/sts" ) @@ -42,27 +41,32 @@ type AttestationData struct { // New creates a new AttestationData with values fed to it and from the result of STS Assume Role. // This requires an identity document along with its signature. The aws account and region will // be extracted from the identity document. -func New(opts *sc.Options, service string) (*AttestationData, error) { - commonName := fmt.Sprintf("%s.%s", opts.Domain, service) +func New(domain, service, region, account, ec2Document, ec2Signature string, useRegionalSTS, omitDomain bool) (string, error) { + commonName := fmt.Sprintf("%s.%s", domain, service) var role string - if opts.OmitDomain { + if omitDomain { role = service } else { role = commonName } - tok, err := getSTSToken(opts.UseRegionalSTS, opts.Region, opts.Account, role) + tok, err := getSTSToken(useRegionalSTS, region, account, role) if err != nil { - return nil, err + return "", err } - return &AttestationData{ + data, err := json.Marshal(&AttestationData{ Role: role, CommonName: commonName, - Document: opts.EC2Document, - Signature: opts.EC2Signature, + Document: ec2Document, + Signature: ec2Signature, Access: *tok.Credentials.AccessKeyId, Secret: *tok.Credentials.SecretAccessKey, Token: *tok.Credentials.SessionToken, - }, nil + }) + if err != nil { + return "", err + } + + return string(data), nil } func getSTSToken(useRegionalSTS bool, region, account, role string) (*sts.AssumeRoleOutput, error) { @@ -113,16 +117,3 @@ func GetECSTaskId() string { } return taskId } - -// GetAttestationData fetches attestation data for all the services mentioned in the config file -func GetAttestationData(opts *sc.Options) ([]*AttestationData, error) { - data := []*AttestationData{} - for _, svc := range opts.Services { - a, err := New(opts, svc.Name) - if err != nil { - return nil, err - } - data = append(data, a) - } - return data, nil -} diff --git a/libs/go/sia/config/config.go b/libs/go/sia/config/config.go index 486ed3ce574..5a9c47ab189 100644 --- a/libs/go/sia/config/config.go +++ b/libs/go/sia/config/config.go @@ -177,7 +177,6 @@ type Options struct { AthenzCACertFile string //filename to store Athenz CA certs ZTSCACertFile string //filename for CA certs when communicating with ZTS ZTSServerName string //ZTS server name, if necessary for tls - ZTSAWSDomains []string //list of domain prefixes for sanDNS entries GenerateRoleKey bool //option to generate a separate key for role certificates RotateKey bool //rotate the private key when refreshing certificates BackupDir string //backup directory for key/cert rotation diff --git a/libs/go/sia/host/provider/provider.go b/libs/go/sia/host/provider/provider.go index 1773ce7e260..980c67f2856 100644 --- a/libs/go/sia/host/provider/provider.go +++ b/libs/go/sia/host/provider/provider.go @@ -27,6 +27,19 @@ import ( "github.com/AthenZ/athenz/libs/go/sia/host/signature" ) +type AttestationRequest struct { + MetaEndPoint string //meta data service endpoint + Domain string //name of the domain for the identity + Service string //name of the service for the identity + ZTSUrl string //the ZTS to contact + OmitDomain bool //attestation role only includes service name + UseRegionalSTS bool //use regional sts endpoint + Account string //name of the account + Region string //region name + EC2Document string //EC2 instance identity document (for AWS only) + EC2Signature string //EC2 instance identity document pkcs7 signature (for AWS only) +} + // Provider is the interface which wraps various Providers known to ZTS // It has methods for providing attestationdata depending on provider type // and generating sub-parts of DN to be including in the CSR and San DNS and URI entries @@ -65,7 +78,7 @@ type Provider interface { GetSuffixes() []string // CloudAttestationData gets the attestation data to prove the identity from metadata of the respective cloud - CloudAttestationData(string, string, string) (string, error) + CloudAttestationData(*AttestationRequest) (string, error) // GetAccountDomainServiceFromMeta gets the account, domain and service info from the respective cloud GetAccountDomainServiceFromMeta(string) (string, string, string, error) diff --git a/libs/go/sia/options/mockawsprovider.go b/libs/go/sia/options/mockawsprovider.go index d87df433a85..4a27363bc3d 100644 --- a/libs/go/sia/options/mockawsprovider.go +++ b/libs/go/sia/options/mockawsprovider.go @@ -11,6 +11,7 @@ import ( "github.com/AthenZ/athenz/libs/go/sia/aws/attestation" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" ) @@ -65,7 +66,7 @@ func (tp MockAWSProvider) GetSuffixes() []string { return []string{} } -func (tp MockAWSProvider) CloudAttestationData(string, string, string) (string, error) { +func (tp MockAWSProvider) CloudAttestationData(*provider.AttestationRequest) (string, error) { a, _ := json.Marshal(&attestation.AttestationData{ Role: "athenz.hockey", }) diff --git a/libs/go/sia/options/mockgcpprovider.go b/libs/go/sia/options/mockgcpprovider.go index 10d180d308d..ea6e8be2464 100644 --- a/libs/go/sia/options/mockgcpprovider.go +++ b/libs/go/sia/options/mockgcpprovider.go @@ -11,6 +11,7 @@ import ( "github.com/AthenZ/athenz/libs/go/sia/gcp/attestation" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" ) @@ -65,7 +66,7 @@ func (tp MockGCPProvider) GetSuffixes() []string { return []string{} } -func (tp MockGCPProvider) CloudAttestationData(base, svc, ztsServerName string) (string, error) { +func (tp MockGCPProvider) CloudAttestationData(*provider.AttestationRequest) (string, error) { a, _ := json.Marshal(&attestation.GoogleAttestationData{ IdentityToken: "abc", }) diff --git a/libs/go/sia/util/util.go b/libs/go/sia/util/util.go index f75aa048b91..ad282e14839 100644 --- a/libs/go/sia/util/util.go +++ b/libs/go/sia/util/util.go @@ -402,7 +402,7 @@ func GenerateRoleCertCSR(key *rsa.PrivateKey, options *RoleCertReqOptions) (stri return GenerateX509CSR(key, csrDetails) } -func GenerateSSHHostCSR(sshPubKeyFile string, domain, service, ip string, ztsAwsDomains []string) (string, error) { +func GenerateSSHHostCSR(sshPubKeyFile string, domain, service, ip string, ztsCloudDomains []string) (string, error) { log.Println("Generating SSH Host Certificate CSR...") @@ -415,7 +415,7 @@ func GenerateSSHHostCSR(sshPubKeyFile string, domain, service, ip string, ztsAws transId := fmt.Sprintf("%x", time.Now().Unix()) hyphenDomain := strings.Replace(domain, ".", "-", -1) principals := []string{} - for _, ztsDomain := range ztsAwsDomains { + for _, ztsDomain := range ztsCloudDomains { host := fmt.Sprintf("%s.%s.%s", service, hyphenDomain, ztsDomain) principals = append(principals, host) } @@ -434,7 +434,7 @@ func GenerateSSHHostCSR(sshPubKeyFile string, domain, service, ip string, ztsAws return string(csr), err } -func GenerateSSHHostRequest(sshPubKeyFile string, domain, service, hostname, ip, instanceId, sshPrincipals string, ztsAwsDomains []string) (*zts.SSHCertRequest, error) { +func GenerateSSHHostRequest(sshPubKeyFile string, domain, service, hostname, ip, instanceId, sshPrincipals string, ztsCloudDomains []string) (*zts.SSHCertRequest, error) { log.Println("Generating SSH Host Certificate Request...") @@ -455,7 +455,7 @@ func GenerateSSHHostRequest(sshPubKeyFile string, domain, service, hostname, ip, if ip != "" { principals = append(principals, ip) } - for _, ztsDomain := range ztsAwsDomains { + for _, ztsDomain := range ztsCloudDomains { host := fmt.Sprintf("%s.%s.%s", service, hyphenDomain, ztsDomain) principals = append(principals, host) } diff --git a/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriManager.java b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriManager.java new file mode 100644 index 00000000000..1ab8fdffaa5 --- /dev/null +++ b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriManager.java @@ -0,0 +1,89 @@ +/* + * + * * Copyright The Athenz Authors + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package com.yahoo.athenz.common.server.spiffe; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +public class SpiffeUriManager { + + private static final Logger LOGGER = LoggerFactory.getLogger(SpiffeUriManager.class); + + public static final String ZTS_PROP_SPIFFE_URI_VALIDATOR_CLASSES = "athenz.zts.spiffe_uri_validator_classes"; + public static final String ZTS_DEFAULT_SPIFFE_URI_VALIDATOR_CLASSES = "com.yahoo.athenz.common.server.spiffe.impl.SpiffeUriTrustDomain,com.yahoo.athenz.common.server.spiffe.impl.SpiffeUriBasic"; + + private final List validators; + + public SpiffeUriManager() { + + final String validatorClasses = System.getProperty(ZTS_PROP_SPIFFE_URI_VALIDATOR_CLASSES, + ZTS_DEFAULT_SPIFFE_URI_VALIDATOR_CLASSES); + + validators = new ArrayList<>(); + String[] validatorClassList = validatorClasses.split(","); + for (String validatorClass : validatorClassList) { + SpiffeUriValidator validator = getValidator(validatorClass.trim()); + if (validator == null) { + throw new IllegalArgumentException("Invalid spiffe uri validator: " + validatorClass); + } + validators.add(validator); + } + } + + SpiffeUriValidator getValidator(String className) { + + LOGGER.debug("Loading spiffe uri validator {}...", className); + + SpiffeUriValidator validator; + try { + validator = (SpiffeUriValidator) Class.forName(className).getDeclaredConstructor().newInstance(); + } catch (Exception ex) { + LOGGER.error("Invalid validator class: {}", className, ex); + return null; + } + return validator; + } + + public boolean validateServiceCertUri(final String spiffeUri, final String domainName, final String serviceName, + final String namespace) { + + for (SpiffeUriValidator validator : validators) { + if (validator.validateServiceCertUri(spiffeUri, domainName, serviceName, namespace)) { + return true; + } + } + LOGGER.error("unable to validate service spiffe uri: {}, domainName: {}, serviceName: {}, namespace: {}", + spiffeUri, domainName, serviceName, namespace); + return false; + } + + public boolean validateRoleCertUri(final String spiffeUri, final String domainName, final String roleName) { + for (SpiffeUriValidator validator : validators) { + if (validator.validateRoleCertUri(spiffeUri, domainName, roleName)) { + return true; + } + } + LOGGER.error("unable to validate role spiffe uri: {}, domainName: {}, roleName: {}", + spiffeUri, domainName, roleName); + return false; + } +} diff --git a/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriValidator.java b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriValidator.java new file mode 100644 index 00000000000..9512a530345 --- /dev/null +++ b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriValidator.java @@ -0,0 +1,45 @@ +/* + * + * * Copyright The Athenz Authors + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package com.yahoo.athenz.common.server.spiffe; + +/** + * An interface that allows system administrators to validate SPIFFE URIs + * based on their own requirements. + */ +public interface SpiffeUriValidator { + + /** + * Validate the SPIFFE URI for service identity certificates based on the system requirements. + * @param spiffeUri the SPIFFE URI to be validated (e.g. spiffe://athenz.domain/sa/service) + * @param domainName the domain name of the service + * @param serviceName the service name + * @param namespace the namespace of the service (typically a Kubernetes namespace) + * @return true if the SPIFFE URI is valid, false otherwise + */ + boolean validateServiceCertUri(final String spiffeUri, final String domainName, final String serviceName, final String namespace); + + /** + * Validate the SPIFFE URI for rike certificates based on the system requirements. + * @param spiffeUri the SPIFFE URI to be validated (e.g. spiffe://athenz.domain/ra/writers) + * @param domainName the domain name of the service + * @param roleName the role name + * @return true if the SPIFFE URI is valid, false otherwise + */ + boolean validateRoleCertUri(final String spiffeUri, final String domainName, final String roleName); +} diff --git a/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriBasic.java b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriBasic.java new file mode 100644 index 00000000000..7ec1ba18dc4 --- /dev/null +++ b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriBasic.java @@ -0,0 +1,52 @@ +/* + * + * * Copyright The Athenz Authors + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package com.yahoo.athenz.common.server.spiffe.impl; + +import com.yahoo.athenz.common.server.spiffe.SpiffeUriValidator; + +/** + * Basic implementation of SpiffeUriValidator interface. This class validates the SPIFFE URI + * with the following formats: + * Service Cert URI: spiffe:///sa/ + * Example: spiffe://athenz/sa/api + * Role Cert URI: spiffe:///ra/ + * Example: spiffe://athenz/ra/readers + */ +public class SpiffeUriBasic implements SpiffeUriValidator { + + /** + * Supported Service Cert URI: spiffe:///sa/ + * Example: spiffe://athenz/sa/api + */ + @Override + public boolean validateServiceCertUri(String spiffeUri, String domainName, String serviceName, String namespace) { + final String reqUri = String.format("spiffe://%s/sa/%s", domainName, serviceName); + return reqUri.equalsIgnoreCase(spiffeUri); + } + + /** + * Supported Role Cert URI: spiffe:///ra/ + * Example: spiffe://athenz/ra/readers + */ + @Override + public boolean validateRoleCertUri(String spiffeUri, String domainName, String roleName) { + final String reqUri = String.format("spiffe://%s/ra/%s", domainName, roleName); + return reqUri.equalsIgnoreCase(spiffeUri); + } +} diff --git a/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriTrustDomain.java b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriTrustDomain.java new file mode 100644 index 00000000000..c992603ff3f --- /dev/null +++ b/libs/java/server_common/src/main/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriTrustDomain.java @@ -0,0 +1,61 @@ +/* + * + * * Copyright The Athenz Authors + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package com.yahoo.athenz.common.server.spiffe.impl; + +import com.yahoo.athenz.common.server.spiffe.SpiffeUriValidator; +import org.eclipse.jetty.util.StringUtil; + +/** + * Trust Domain implementation of SpiffeUriValidator interface. This class validates the SPIFFE URI + * with the following formats: + * Service Cert URI: spiffe:///ns//sa/. + * Example: spiffe://athenz.io/ns/prod/sa/athenz.api + * Role Cert URI: spiffe:///ns//ra/ + * Example: spiffe://athenz.io/ns/athenz/ra/readers + */ +public class SpiffeUriTrustDomain implements SpiffeUriValidator { + + private static final String SPIFFE_DEFAULT_NAMESPACE = "default"; + + private static final String SPIFFE_PROP_TRUST_DOMAIN = "athenz.zts.spiffe_trust_domain"; + private static final String SPIFFE_TRUST_DOMAIN = System.getProperty(SPIFFE_PROP_TRUST_DOMAIN, "athenz.io"); + + /** + * Service Cert URI: spiffe:///ns//sa/. + * Example: spiffe://athenz.io/ns/prod/sa/athenz.api + */ + @Override + public boolean validateServiceCertUri(String spiffeUri, String domainName, String serviceName, String namespace) { + final String ns = StringUtil.isEmpty(namespace) ? SPIFFE_DEFAULT_NAMESPACE : namespace; + final String reqUri = String.format("spiffe://%s/ns/%s/sa/%s.%s", SPIFFE_TRUST_DOMAIN, + ns, domainName, serviceName); + return reqUri.equalsIgnoreCase(spiffeUri); + } + + /** + * Role Cert URI: spiffe:///ns//ra/ + * Example: spiffe://athenz.io/ns/athenz/ra/readers + */ + @Override + public boolean validateRoleCertUri(String spiffeUri, String domainName, String roleName) { + final String reqUri = String.format("spiffe://%s/ns/%s/ra/%s", SPIFFE_TRUST_DOMAIN, + domainName, roleName); + return reqUri.equalsIgnoreCase(spiffeUri); + } +} diff --git a/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriManagerTest.java b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriManagerTest.java new file mode 100644 index 00000000000..8b2197eeb5e --- /dev/null +++ b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriManagerTest.java @@ -0,0 +1,87 @@ +/* + * Copyright The Athenz Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yahoo.athenz.common.server.spiffe; + +import com.yahoo.athenz.common.server.spiffe.impl.SpiffeUriTrustDomain; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.testng.Assert.*; + +public class SpiffeUriManagerTest { + + @BeforeClass + public void setup() { + System.setProperty("athenz.zts.spiffe_trust_domain", "spiffe.athenz.io"); + } + + @Test + public void testValidateServiceCertUriDefaultClasses() { + + System.clearProperty("athenz.zts.spiffe_uri_validator_classes"); + SpiffeUriManager manager = new SpiffeUriManager(); + + assertTrue(manager.validateServiceCertUri("spiffe://athenz/sa/api", "athenz", "api", null)); + assertTrue(manager.validateServiceCertUri("spiffe://athenz/sa/api", "athenz", "api", "default")); + + assertFalse(manager.validateServiceCertUri("spiffe://athenz/sa/api", "athenz.prod", "api", "default")); + assertFalse(manager.validateServiceCertUri("spiffe://athenz/sa/api", "athenz", "backend", "default")); + + assertTrue(manager.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.api", "athenz", "api", null)); + assertTrue(manager.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.api", "athenz", "api", "default")); + assertTrue(manager.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/prod/sa/athenz.api", "athenz", "api", "prod")); + + assertFalse(manager.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.api", "athenz", "api", "prod")); + assertFalse(manager.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.backend", "athenz", "api", "default")); + assertFalse(manager.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/prod/sa/athenz.api", "athenz.prod", "api", "prod")); + + assertFalse(manager.validateServiceCertUri("spiffe://athenz.io/ns/prod/sa/athenz.api", "athenz", "api", "prod")); + } + + @Test + public void testValidateRoleCertUriDefaultClasses() { + + System.clearProperty("athenz.zts.spiffe_uri_validator_classes"); + SpiffeUriManager manager = new SpiffeUriManager(); + + assertTrue(manager.validateRoleCertUri("spiffe://athenz/ra/readers", "athenz", "readers")); + assertTrue(manager.validateRoleCertUri("spiffe://athenz/ra/writers", "athenz", "writers")); + + assertFalse(manager.validateRoleCertUri("spiffe://athenz/ra/readers", "athenz.prod", "readers")); + assertFalse(manager.validateRoleCertUri("spiffe://athenz/ra/readers", "athenz", "writers")); + + assertTrue(manager.validateRoleCertUri("spiffe://spiffe.athenz.io/ns/athenz/ra/readers", "athenz", "readers")); + + assertFalse(manager.validateRoleCertUri("spiffe://spiffe.athenz.io/ns/athenz/ra/readers", "athenz", "writers")); + assertFalse(manager.validateRoleCertUri("spiffe://spiffe.athenz.io/ns/athenz/ra/readers", "athenz.prod", "readers")); + + assertFalse(manager.validateRoleCertUri("spiffe://athenz.io/ns/athenz/ra/readers", "athenz", "readers")); + } + + @Test + public void testValidateInvalidClass() { + + System.setProperty("athenz.zts.spiffe_uri_validator_classes", "com.yahoo.athenz.common.server.spiffe.impl.InvalidClass"); + try { + new SpiffeUriManager(); + fail(); + } catch (IllegalArgumentException ex) { + assertTrue(ex.getMessage().contains("Invalid spiffe uri validator: com.yahoo.athenz.common.server.spiffe.impl.InvalidClass")); + } + System.clearProperty("athenz.zts.spiffe_uri_validator_classes"); + } +} diff --git a/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriValidatorTest.java b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriValidatorTest.java new file mode 100644 index 00000000000..be98dbb77b2 --- /dev/null +++ b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/SpiffeUriValidatorTest.java @@ -0,0 +1,53 @@ +/* + * Copyright The Athenz Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yahoo.athenz.common.server.spiffe; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class SpiffeUriValidatorTest { + + @Test + public void testValidate() { + + final String trustDomain = "athenz.io"; + + SpiffeUriValidator validator = new SpiffeUriValidator() { + @Override + public boolean validateServiceCertUri(String spiffeUri, String domainName, String serviceName, String namespace) { + final String expectedUri = String.format("spiffe://%s/ns/%s/sa/%s.%s", trustDomain, namespace, + domainName, serviceName); + return spiffeUri.equals(expectedUri); + } + + @Override + public boolean validateRoleCertUri(String spiffeUri, String domainName, String roleName) { + final String expectedUri = String.format("spiffe://%s/ns/%s/ra/%s", trustDomain, + domainName, roleName); + return spiffeUri.equals(expectedUri); + } + }; + + assertTrue(validator.validateServiceCertUri("spiffe://athenz.io/ns/prod/sa/athenz.api", "athenz", "api", "prod")); + assertFalse(validator.validateServiceCertUri("spiffe://athenz.io/ns/prod/sa/athenz.api", "athenz", "api", "dev")); + + assertTrue(validator.validateRoleCertUri("spiffe://athenz.io/ns/athenz/ra/readers", "athenz", "readers")); + assertFalse(validator.validateRoleCertUri("spiffe://athenz.io/ns/athenz/ra/readers", "athenz", "writers")); + } +} diff --git a/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriBasicTest.java b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriBasicTest.java new file mode 100644 index 00000000000..198c9106229 --- /dev/null +++ b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriBasicTest.java @@ -0,0 +1,47 @@ +/* + * Copyright The Athenz Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yahoo.athenz.common.server.spiffe.impl; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class SpiffeUriBasicTest { + + @Test + public void testValidateServiceCertUri() { + SpiffeUriBasic validator = new SpiffeUriBasic(); + + assertTrue(validator.validateServiceCertUri("spiffe://athenz/sa/api", "athenz", "api", null)); + assertTrue(validator.validateServiceCertUri("spiffe://athenz/sa/api", "athenz", "api", "default")); + + assertFalse(validator.validateServiceCertUri("spiffe://athenz/sa/api", "athenz.prod", "api", "default")); + assertFalse(validator.validateServiceCertUri("spiffe://athenz/sa/api", "athenz", "backend", "default")); + } + + @Test + public void testValidateRoleCertUri() { + SpiffeUriBasic validator = new SpiffeUriBasic(); + + assertTrue(validator.validateRoleCertUri("spiffe://athenz/ra/readers", "athenz", "readers")); + assertTrue(validator.validateRoleCertUri("spiffe://athenz/ra/writers", "athenz", "writers")); + + assertFalse(validator.validateRoleCertUri("spiffe://athenz/ra/readers", "athenz.prod", "readers")); + assertFalse(validator.validateRoleCertUri("spiffe://athenz/ra/readers", "athenz", "writers")); + } +} diff --git a/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriTrustDomainTest.java b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriTrustDomainTest.java new file mode 100644 index 00000000000..1d4b5af2ae5 --- /dev/null +++ b/libs/java/server_common/src/test/java/com/yahoo/athenz/common/server/spiffe/impl/SpiffeUriTrustDomainTest.java @@ -0,0 +1,58 @@ +/* + * Copyright The Athenz Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.yahoo.athenz.common.server.spiffe.impl; + +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class SpiffeUriTrustDomainTest { + + @BeforeClass + public void setup() { + System.setProperty("athenz.zts.spiffe_trust_domain", "spiffe.athenz.io"); + } + + @Test + public void testValidateServiceCertUri() { + + SpiffeUriTrustDomain validator = new SpiffeUriTrustDomain(); + assertTrue(validator.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.api", "athenz", "api", null)); + assertTrue(validator.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.api", "athenz", "api", "default")); + assertTrue(validator.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/prod/sa/athenz.api", "athenz", "api", "prod")); + + assertFalse(validator.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.api", "athenz", "api", "prod")); + assertFalse(validator.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/default/sa/athenz.backend", "athenz", "api", "default")); + assertFalse(validator.validateServiceCertUri("spiffe://spiffe.athenz.io/ns/prod/sa/athenz.api", "athenz.prod", "api", "prod")); + + assertFalse(validator.validateServiceCertUri("spiffe://athenz.io/ns/prod/sa/athenz.api", "athenz", "api", "prod")); + } + + @Test + public void testValidateRoleCertUri() { + + SpiffeUriTrustDomain validator = new SpiffeUriTrustDomain(); + assertTrue(validator.validateRoleCertUri("spiffe://spiffe.athenz.io/ns/athenz/ra/readers", "athenz", "readers")); + + assertFalse(validator.validateRoleCertUri("spiffe://spiffe.athenz.io/ns/athenz/ra/readers", "athenz", "writers")); + assertFalse(validator.validateRoleCertUri("spiffe://spiffe.athenz.io/ns/athenz/ra/readers", "athenz.prod", "readers")); + + assertFalse(validator.validateRoleCertUri("spiffe://athenz.io/ns/athenz/ra/readers", "athenz", "readers")); + } +} diff --git a/libs/java/server_k8s_common/src/test/java/io/athenz/server/k8s/common/impl/KubernetesPodResolverUtilTest.java b/libs/java/server_k8s_common/src/test/java/io/athenz/server/k8s/common/impl/KubernetesPodResolverUtilTest.java index 3ca48c54347..03abcc08719 100644 --- a/libs/java/server_k8s_common/src/test/java/io/athenz/server/k8s/common/impl/KubernetesPodResolverUtilTest.java +++ b/libs/java/server_k8s_common/src/test/java/io/athenz/server/k8s/common/impl/KubernetesPodResolverUtilTest.java @@ -43,24 +43,16 @@ public void testGetPodSiblings() throws UnknownHostException { @Test public void testGetPodSiblingsEmptyServiceNameException() { String serviceName = ""; - Exception ex = null; try { KubernetesPodResolverUtil.getSiblingPodIPs(serviceName); - } catch (IllegalArgumentException | UnknownHostException e) { - ex = e; - } - if (ex == null) { Assert.fail("expected IllegalArgumentException not thrown"); + } catch (IllegalArgumentException | UnknownHostException ignored) { } - Exception nullEx = null; try { KubernetesPodResolverUtil.getSiblingPodIPs(null); - } catch (IllegalArgumentException | UnknownHostException e) { - nullEx = e; - } - if (nullEx == null) { Assert.fail("expected IllegalArgumentException not thrown"); + } catch (IllegalArgumentException | UnknownHostException ignored) { } } @@ -69,15 +61,18 @@ public void testGetPodSiblingsInvalidHostnameException() { String serviceName = "foo"; MockedStatic inetAddressMock = Mockito.mockStatic(InetAddress.class); inetAddressMock.when(() -> InetAddress.getAllByName(serviceName)).thenThrow(UnknownHostException.class); - Exception ex = null; try { KubernetesPodResolverUtil.getSiblingPodIPs(serviceName); - } catch (IllegalArgumentException | UnknownHostException e) { - ex = e; - } - if (ex == null) { Assert.fail("expected UnknownHostException not thrown"); + } catch (IllegalArgumentException | UnknownHostException ignored) { } inetAddressMock.close(); } + + @Test + public void testConstructor() { + // test to get code coverage to 100% + KubernetesPodResolverUtil util = new KubernetesPodResolverUtil(); + Assert.assertNotNull(util); + } } diff --git a/provider/aws/sia-ec2/cmd/siad/main.go b/provider/aws/sia-ec2/cmd/siad/main.go index 9c373ad5c69..67f2c2ae4ac 100644 --- a/provider/aws/sia-ec2/cmd/siad/main.go +++ b/provider/aws/sia-ec2/cmd/siad/main.go @@ -23,7 +23,7 @@ import ( "os" "strings" - "github.com/AthenZ/athenz/libs/go/sia/aws/agent" + "github.com/AthenZ/athenz/libs/go/sia/agent" "github.com/AthenZ/athenz/libs/go/sia/aws/options" "github.com/AthenZ/athenz/libs/go/sia/ssh/hostkey" "github.com/AthenZ/athenz/libs/go/sia/util" @@ -111,7 +111,7 @@ func main() { opts.PrivateIp = privateIp opts.ZTSCACertFile = *ztsCACert opts.ZTSServerName = *ztsServerName - opts.ZTSAWSDomains = strings.Split(*dnsDomains, ",") + opts.ZTSCloudDomains = strings.Split(*dnsDomains, ",") opts.SpiffeNamespace = "default" provider := sia.EC2Provider{ diff --git a/provider/aws/sia-ec2/provider.go b/provider/aws/sia-ec2/provider.go index 55c9092a6c3..87da48f095a 100644 --- a/provider/aws/sia-ec2/provider.go +++ b/provider/aws/sia-ec2/provider.go @@ -21,12 +21,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" - "github.com/AthenZ/athenz/libs/go/sia/host/ip" - "github.com/AthenZ/athenz/libs/go/sia/host/signature" - "github.com/AthenZ/athenz/libs/go/sia/host/utils" "log" "net" "net/url" + + "github.com/AthenZ/athenz/libs/go/sia/aws/attestation" + "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" + "github.com/AthenZ/athenz/libs/go/sia/host/signature" + "github.com/AthenZ/athenz/libs/go/sia/host/utils" ) type EC2Provider struct { @@ -80,8 +83,8 @@ func (ec2 EC2Provider) GetSuffixes() []string { return []string{} } -func (ec2 EC2Provider) CloudAttestationData(_, _, _ string) (string, error) { - return "", fmt.Errorf("not implemented") +func (ec2 EC2Provider) CloudAttestationData(request *provider.AttestationRequest) (string, error) { + return attestation.New(request.Domain, request.Service, request.Region, request.Account, request.EC2Document, request.EC2Signature, request.UseRegionalSTS, request.OmitDomain) } func (ec2 EC2Provider) GetAccountDomainServiceFromMeta(_ string) (string, string, string, error) { diff --git a/provider/aws/sia-eks/cmd/siad/main.go b/provider/aws/sia-eks/cmd/siad/main.go index 15b2f654614..d80b07c9066 100644 --- a/provider/aws/sia-eks/cmd/siad/main.go +++ b/provider/aws/sia-eks/cmd/siad/main.go @@ -23,7 +23,7 @@ import ( "os" "strings" - "github.com/AthenZ/athenz/libs/go/sia/aws/agent" + "github.com/AthenZ/athenz/libs/go/sia/agent" "github.com/AthenZ/athenz/libs/go/sia/aws/meta" "github.com/AthenZ/athenz/libs/go/sia/aws/options" "github.com/AthenZ/athenz/libs/go/sia/host/utils" @@ -90,7 +90,7 @@ func main() { opts.Ssh = false opts.ZTSCACertFile = *ztsCACert opts.ZTSServerName = *ztsServerName - opts.ZTSAWSDomains = strings.Split(*dnsDomains, ",") + opts.ZTSCloudDomains = strings.Split(*dnsDomains, ",") spiffeNamespace, addlSanDNSEntries := utils.GetK8SHostnames("cluster.local", false) opts.SpiffeNamespace = spiffeNamespace if len(addlSanDNSEntries) > 0 { diff --git a/provider/aws/sia-eks/provider.go b/provider/aws/sia-eks/provider.go index 7fa4b775ab8..39de9586206 100644 --- a/provider/aws/sia-eks/provider.go +++ b/provider/aws/sia-eks/provider.go @@ -21,11 +21,14 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "net" + "net/url" + + "github.com/AthenZ/athenz/libs/go/sia/aws/attestation" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" "github.com/AthenZ/athenz/libs/go/sia/host/utils" - "net" - "net/url" ) type EKSProvider struct { @@ -78,8 +81,8 @@ func (eks EKSProvider) GetSuffixes() []string { return []string{} } -func (eks EKSProvider) CloudAttestationData(_, _, _ string) (string, error) { - return "", fmt.Errorf("not implemented") +func (eks EKSProvider) CloudAttestationData(request *provider.AttestationRequest) (string, error) { + return attestation.New(request.Domain, request.Service, request.Region, request.Account, request.EC2Document, request.EC2Signature, request.UseRegionalSTS, request.OmitDomain) } func (eks EKSProvider) GetAccountDomainServiceFromMeta(_ string) (string, string, string, error) { diff --git a/provider/aws/sia-fargate/cmd/siad/main.go b/provider/aws/sia-fargate/cmd/siad/main.go index d2024a7d25a..fbe887797de 100644 --- a/provider/aws/sia-fargate/cmd/siad/main.go +++ b/provider/aws/sia-fargate/cmd/siad/main.go @@ -23,7 +23,7 @@ import ( "os" "strings" - "github.com/AthenZ/athenz/libs/go/sia/aws/agent" + "github.com/AthenZ/athenz/libs/go/sia/agent" "github.com/AthenZ/athenz/libs/go/sia/aws/options" "github.com/AthenZ/athenz/provider/aws/sia-fargate" ) @@ -91,7 +91,7 @@ func main() { opts.Ssh = false opts.ZTSCACertFile = *ztsCACert opts.ZTSServerName = *ztsServerName - opts.ZTSAWSDomains = strings.Split(*dnsDomains, ",") + opts.ZTSCloudDomains = strings.Split(*dnsDomains, ",") opts.InstanceId = taskId if *udsPath != "" { diff --git a/provider/aws/sia-fargate/provider.go b/provider/aws/sia-fargate/provider.go index 59f12e0f16c..4827c5ad413 100644 --- a/provider/aws/sia-fargate/provider.go +++ b/provider/aws/sia-fargate/provider.go @@ -21,10 +21,13 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" - "github.com/AthenZ/athenz/libs/go/sia/host/ip" - "github.com/AthenZ/athenz/libs/go/sia/host/signature" "net" "net/url" + + "github.com/AthenZ/athenz/libs/go/sia/aws/attestation" + "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" + "github.com/AthenZ/athenz/libs/go/sia/host/signature" ) type FargateProvider struct { @@ -77,8 +80,8 @@ func (fargate FargateProvider) GetSuffixes() []string { return []string{} } -func (fargate FargateProvider) CloudAttestationData(_, _, _ string) (string, error) { - return "", fmt.Errorf("not implemented") +func (fargate FargateProvider) CloudAttestationData(request *provider.AttestationRequest) (string, error) { + return attestation.New(request.Domain, request.Service, request.Region, request.Account, request.EC2Document, request.EC2Signature, request.UseRegionalSTS, request.OmitDomain) } func (fargate FargateProvider) GetAccountDomainServiceFromMeta(_ string) (string, string, string, error) { diff --git a/provider/gcp/sia-gce/provider.go b/provider/gcp/sia-gce/provider.go index 261668268dc..39854f0e237 100644 --- a/provider/gcp/sia-gce/provider.go +++ b/provider/gcp/sia-gce/provider.go @@ -21,13 +21,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "net" + "net/url" + gcpa "github.com/AthenZ/athenz/libs/go/sia/gcp/attestation" "github.com/AthenZ/athenz/libs/go/sia/gcp/meta" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" "github.com/AthenZ/athenz/libs/go/sia/host/utils" - "net" - "net/url" ) type GCEProvider struct { @@ -84,8 +86,8 @@ func (gce GCEProvider) GetSuffixes() []string { return []string{} } -func (gce GCEProvider) CloudAttestationData(base, svc, ztsServerName string) (string, error) { - return gcpa.New(base, svc, ztsServerName) +func (gce GCEProvider) CloudAttestationData(request *provider.AttestationRequest) (string, error) { + return gcpa.New(request.MetaEndPoint, request.Service, request.ZTSUrl) } func (gce GCEProvider) GetAccountDomainServiceFromMeta(base string) (string, string, string, error) { diff --git a/provider/gcp/sia-gke/provider.go b/provider/gcp/sia-gke/provider.go index b779c61a504..ba536c73cc3 100644 --- a/provider/gcp/sia-gke/provider.go +++ b/provider/gcp/sia-gke/provider.go @@ -21,13 +21,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "net" + "net/url" + gcpa "github.com/AthenZ/athenz/libs/go/sia/gcp/attestation" "github.com/AthenZ/athenz/libs/go/sia/gcp/meta" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" "github.com/AthenZ/athenz/libs/go/sia/host/utils" - "net" - "net/url" ) type GKEProvider struct { @@ -84,8 +86,8 @@ func (gke GKEProvider) GetSuffixes() []string { return []string{} } -func (gke GKEProvider) CloudAttestationData(base, svc, ztsServerName string) (string, error) { - return gcpa.New(base, svc, ztsServerName) +func (gke GKEProvider) CloudAttestationData(request *provider.AttestationRequest) (string, error) { + return gcpa.New(request.MetaEndPoint, request.Service, request.ZTSUrl) } func (gke GKEProvider) GetAccountDomainServiceFromMeta(base string) (string, string, string, error) { diff --git a/provider/gcp/sia-run/provider.go b/provider/gcp/sia-run/provider.go index 77d5a7d3924..2bf2c2d691b 100644 --- a/provider/gcp/sia-run/provider.go +++ b/provider/gcp/sia-run/provider.go @@ -21,13 +21,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "net" + "net/url" + gcpa "github.com/AthenZ/athenz/libs/go/sia/gcp/attestation" "github.com/AthenZ/athenz/libs/go/sia/gcp/meta" "github.com/AthenZ/athenz/libs/go/sia/host/ip" + "github.com/AthenZ/athenz/libs/go/sia/host/provider" "github.com/AthenZ/athenz/libs/go/sia/host/signature" "github.com/AthenZ/athenz/libs/go/sia/host/utils" - "net" - "net/url" ) type GCPRunProvider struct { @@ -84,8 +86,8 @@ func (gcprun GCPRunProvider) GetSuffixes() []string { return []string{} } -func (gcprun GCPRunProvider) CloudAttestationData(base, svc, ztsServerName string) (string, error) { - return gcpa.New(base, svc, ztsServerName) +func (gcprun GCPRunProvider) CloudAttestationData(request *provider.AttestationRequest) (string, error) { + return gcpa.New(request.MetaEndPoint, request.Service, request.ZTSUrl) } func (gcprun GCPRunProvider) GetAccountDomainServiceFromMeta(base string) (string, string, string, error) { diff --git a/servers/zts/conf/zts.properties b/servers/zts/conf/zts.properties index 7ca45f937e8..85934efb89c 100644 --- a/servers/zts/conf/zts.properties +++ b/servers/zts/conf/zts.properties @@ -810,3 +810,14 @@ athenz.zts.k8s_provider_distribution_validator_factory_class=com.yahoo.athenz.in # as duration * timeunit, and the CertRecordCleaner will run at this defined interval. #athenz.zts.cert_record_cleaner_duration=1 #athenz.zts.cert_record_cleaner_timeunit=day + +# This property specifies a comma separated list of Spiffe URI validator classes. +# The class must implement the com.yahoo.athenz.common.server.spiffe interface from +# the athenz-server-common package. The server will use the classes in the order +# they are specified in the list. By default, the server uses both SpiffeUriTrustDomain +# and SpiffeUriBasic classes. +#athenz.zts.spiffe_uri_validator_classes=com.yahoo.athenz.common.server.spiffe.impl.SpiffeUriTrustDomain,com.yahoo.athenz.common.server.spiffe.impl.SpiffeUriBasic + +# This property specifies the Spiffe trust domain that the server will use to validate +# the Spiffe URI in the request if SpiffeUriTrustDomain validator is enabled. +#athenz.zts.spiffe_trust_domain=athenz.io diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java index 4dbe4934dc7..5ccaea68ead 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java @@ -55,7 +55,6 @@ public final class ZTSConsts { public static final String ZTS_PROP_READ_ONLY_MODE = "athenz.zts.read_only_mode"; public static final String ZTS_PROP_HEALTH_CHECK_PATH = "athenz.zts.health_check_path"; public static final String ZTS_PROP_SERVER_REGION = "athenz.zts.server_region"; - public static final String ZTS_PROP_SPIFFE_TRUST_DOMAIN = "athenz.zts.spiffe_trust_domain"; public static final String ZTS_PROP_AWS_CREDS_CACHE_TIMEOUT = "athenz.zts.aws_creds_cache_timeout"; public static final String ZTS_PROP_AWS_CREDS_INVALID_CACHE_TIMEOUT = "athenz.zts.aws_creds_invalid_cache_timeout"; diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java index d4cb010dd96..93083a3ceb8 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java @@ -45,6 +45,7 @@ import com.yahoo.athenz.common.server.rest.Http; import com.yahoo.athenz.common.server.rest.Http.AuthorityList; import com.yahoo.athenz.common.server.ServerResourceException; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import com.yahoo.athenz.common.server.ssh.SSHCertRecord; import com.yahoo.athenz.common.server.status.StatusCheckException; import com.yahoo.athenz.common.server.status.StatusChecker; @@ -183,6 +184,7 @@ public class ZTSImpl implements KeyStore, ZTSHandler { private final Object updateJWKMutex = new Object(); protected ExternalCredentialsManager externalCredentialsManager; protected DynamicConfigInteger serviceCertDefaultExpiryMins; + protected SpiffeUriManager spiffeUriManager; private static final String TYPE_DOMAIN_NAME = "DomainName"; private static final String TYPE_SIMPLE_NAME = "SimpleName"; @@ -386,6 +388,10 @@ public ZTSImpl(CloudStore implCloudStore, DataStore implDataStore) { // initialize our external credentials providers externalCredentialsManager = new ExternalCredentialsManager(authorizer); + + // load spiffe uri validators + + spiffeUriManager = new SpiffeUriManager(); } void loadJsonMapper() { @@ -2875,7 +2881,7 @@ public RoleToken postRoleCertificateRequest(ResourceContext ctx, String domainNa X509RoleCertRequest certReq; try { - certReq = new X509RoleCertRequest(req.getCsr()); + certReq = new X509RoleCertRequest(req.getCsr(), spiffeUriManager); } catch (CryptoException ex) { throw requestError("Unable to parse PKCS10 CSR: " + ex.getMessage(), caller, domainName, principalDomain); @@ -3224,7 +3230,7 @@ public RoleCertificate postRoleCertificateRequestExt(ResourceContext ctx, RoleCe X509RoleCertRequest certReq; try { - certReq = new X509RoleCertRequest(req.getCsr()); + certReq = new X509RoleCertRequest(req.getCsr(), spiffeUriManager); } catch (CryptoException ex) { throw requestError("Unable to parse PKCS10 CSR: " + ex.getMessage(), caller, principalDomain, principalDomain); @@ -3769,7 +3775,7 @@ public Response postInstanceRegisterInformation(ResourceContext ctx, InstanceReg X509ServiceCertRequest certReq; try { - certReq = new X509ServiceCertRequest(info.getCsr()); + certReq = new X509ServiceCertRequest(info.getCsr(), spiffeUriManager); } catch (CryptoException ex) { throw requestError("unable to parse PKCS10 CSR: " + ex.getMessage(), caller, domain, principalDomain); @@ -4307,7 +4313,7 @@ InstanceIdentity processProviderX509RefreshRequest(ResourceContext ctx, DomainDa final String principalDomain = principal.getDomain(); X509ServiceCertRequest certReq; try { - certReq = new X509ServiceCertRequest(info.getCsr()); + certReq = new X509ServiceCertRequest(info.getCsr(), spiffeUriManager); } catch (CryptoException ex) { throw requestError("unable to parse PKCS10 CSR", caller, domain, principalDomain); } @@ -4823,7 +4829,7 @@ public Identity postInstanceRefreshRequest(ResourceContext ctx, String domain, X509ServiceCertRequest x509CertReq; try { - x509CertReq = new X509ServiceCertRequest(req.getCsr()); + x509CertReq = new X509ServiceCertRequest(req.getCsr(), spiffeUriManager); } catch (CryptoException ex) { throw requestError("Unable to parse PKCS10 certificate request", caller, domain, principalDomain); diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509CertRequest.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509CertRequest.java index 8f3dabfa03b..506a6334d08 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509CertRequest.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509CertRequest.java @@ -23,6 +23,7 @@ import java.util.regex.Pattern; import com.yahoo.athenz.common.server.dns.HostnameResolver; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import com.yahoo.athenz.common.utils.X509CertUtils; import com.yahoo.athenz.zts.CertType; import com.yahoo.athenz.zts.ZTSConsts; @@ -39,9 +40,6 @@ public class X509CertRequest { private static final Logger LOGGER = LoggerFactory.getLogger(X509CertRequest.class); private static final Pattern WHITESPACE_PATTERN = Pattern.compile("\\s+"); - protected static final String SPIFFE_NAMESPACE_AGENT = "ns"; - protected static final String SPIFFE_TRUST_DOMAIN = System.getProperty(ZTSConsts.ZTS_PROP_SPIFFE_TRUST_DOMAIN, "athenz.io"); - protected PKCS10CertificationRequest certReq; protected String instanceId; protected String uriHostname; @@ -53,8 +51,9 @@ public class X509CertRequest { protected List providerDnsNames; protected List ipAddresses; protected List uris; + protected SpiffeUriManager spiffeUriManager; - public X509CertRequest(String csr) throws CryptoException { + public X509CertRequest(String csr, SpiffeUriManager spiffeUriManager) throws CryptoException { certReq = Crypto.getPKCS10CertRequest(csr); if (certReq == null) { @@ -106,6 +105,10 @@ public X509CertRequest(String csr) throws CryptoException { if (instanceId == null) { instanceId = X509CertUtils.extractRequestInstanceIdFromDnsNames(dnsNames); } + + // save the spiffe uri manager object + + this.spiffeUriManager = spiffeUriManager; } public PKCS10CertificationRequest getCertReq() { diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509RoleCertRequest.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509RoleCertRequest.java index e171ca28bfd..d89881e95a1 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509RoleCertRequest.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509RoleCertRequest.java @@ -18,6 +18,7 @@ import com.yahoo.athenz.auth.AuthorityConsts; import com.yahoo.athenz.auth.util.Crypto; import com.yahoo.athenz.auth.util.CryptoException; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import com.yahoo.athenz.common.utils.X509CertUtils; import com.yahoo.athenz.zts.ZTSConsts; import com.yahoo.athenz.zts.utils.ZTSUtils; @@ -32,17 +33,15 @@ public class X509RoleCertRequest extends X509CertRequest { private static final Logger LOGGER = LoggerFactory.getLogger(X509RoleCertRequest.class); - private static final String SPIFFE_ROLE_AGENT = "ra"; - protected String reqRoleName; protected String reqRoleDomain; protected String rolePrincipal; - public X509RoleCertRequest(String csr) throws CryptoException { + public X509RoleCertRequest(String csr, SpiffeUriManager spiffeUriManager) throws CryptoException { // parse the csr request - super(csr); + super(csr, spiffeUriManager); // make sure the CN is a valid role name @@ -230,26 +229,13 @@ public boolean validateIPAddress(X509Certificate cert, final String ip) { public boolean validateSpiffeURI(final String domainName, final String roleName) { - // the expected format are: - // spiffe:///ra/ - // e.g. spiffe://sports/ra/hockey-writers - // spiffe:///ns//ra/ - // e.g. spiffe://athenz.io/ns/sports/ra/hockey-writers + // validate the spiffe uri according to our configured validators if (spiffeUri == null) { return true; } - final String reqUri1 = "spiffe://" + domainName + "/" + SPIFFE_ROLE_AGENT + "/" + roleName; - final String reqUri2 = "spiffe://" + SPIFFE_TRUST_DOMAIN + "/" + SPIFFE_NAMESPACE_AGENT + "/" + - domainName + "/" + SPIFFE_ROLE_AGENT + "/" + roleName; - boolean uriVerified = reqUri1.equalsIgnoreCase(spiffeUri) || reqUri2.equalsIgnoreCase(spiffeUri); - - if (!uriVerified) { - LOGGER.error("validateSpiffeURI: spiffe uri mismatch: {}/{}/{}", spiffeUri, reqUri1, reqUri2); - } - - return uriVerified; + return spiffeUriManager.validateRoleCertUri(spiffeUri, domainName, roleName); } } diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequest.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequest.java index 61737a4889f..353b9594158 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequest.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequest.java @@ -19,20 +19,14 @@ import java.util.Set; import com.yahoo.athenz.auth.util.CryptoException; import com.yahoo.athenz.common.server.dns.HostnameResolver; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import com.yahoo.athenz.zts.cache.DataCache; import org.eclipse.jetty.util.StringUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class X509ServiceCertRequest extends X509CertRequest { - private static final Logger LOGGER = LoggerFactory.getLogger(X509ServiceCertRequest.class); - - public static final String SPIFFE_SERVICE_AGENT = "sa"; - public static final String SPIFFE_DEFAULT_NAMESPACE = "default"; - - public X509ServiceCertRequest(String csr) throws CryptoException { - super(csr); + public X509ServiceCertRequest(String csr, SpiffeUriManager spiffeUriManager) throws CryptoException { + super(csr, spiffeUriManager); } public boolean validate(final String domainName, final String serviceName, final String provider, @@ -91,26 +85,12 @@ public boolean validate(final String domainName, final String serviceName, final public boolean validateSpiffeURI(final String domainName, final String serviceName, final String namespace) { - // the expected format are: - // spiffe:///sa/ - // e.g. spiffe://sports/sa/api - // spiffe:///ns//sa/ - // e.g. spiffe://athenz.io/ns/default/sa/sports.api + // validate the spiffe uri according to our configured validators if (spiffeUri == null) { return true; } - final String ns = StringUtil.isEmpty(namespace) ? SPIFFE_DEFAULT_NAMESPACE : namespace; - final String reqUri1 = "spiffe://" + domainName + "/" + SPIFFE_SERVICE_AGENT + "/" + serviceName; - final String reqUri2 = "spiffe://" + SPIFFE_TRUST_DOMAIN + "/" + SPIFFE_NAMESPACE_AGENT + "/" + - ns + "/" + SPIFFE_SERVICE_AGENT + "/" + domainName + "." + serviceName; - boolean uriVerified = reqUri1.equalsIgnoreCase(spiffeUri) || reqUri2.equalsIgnoreCase(spiffeUri); - - if (!uriVerified) { - LOGGER.error("validateSpiffeURI: spiffe uri mismatch: {}/{}/{}", spiffeUri, reqUri1, reqUri2); - } - - return uriVerified; + return spiffeUriManager.validateServiceCertUri(spiffeUri, domainName, serviceName, namespace); } } diff --git a/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java b/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java index b6d715789ef..e23ef12ea09 100644 --- a/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java +++ b/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java @@ -4218,7 +4218,7 @@ public void testValidateRoleCertificateRequestMismatchEmail() throws IOException Path path = Paths.get("src/test/resources/valid_email.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); zts.validCertSubjectOrgValues = null; assertFalse(zts.validateRoleCertificateRequest(certReq, "sports.standings", @@ -4231,7 +4231,7 @@ public void testValidateRoleCertificateRequestNoEmail() throws IOException { Path path = Paths.get("src/test/resources/valid_noemail.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); zts.validCertSubjectOrgValues = null; assertFalse(zts.validateRoleCertificateRequest(certReq, "no-email", null, @@ -4244,7 +4244,7 @@ public void testValidateRoleCertificateRequestInvalidOField() throws IOException Path path = Paths.get("src/test/resources/valid_email.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); Set validOValues = new HashSet<>(); validOValues.add("InvalidCompany"); @@ -4259,7 +4259,7 @@ public void testValidateRoleCertificateRequest() throws IOException { Path path = Paths.get("src/test/resources/valid_email.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); zts.validCertSubjectOrgValues = null; assertTrue(zts.validateRoleCertificateRequest(certReq, "sports.scores", @@ -4285,7 +4285,7 @@ public void testValidateRoleCertificateRequestOU() throws IOException { zts.validCertSubjectOrgUnitValues = ouValues; zts.verifyCertSubjectOU = true; - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); assertFalse(zts.validateRoleCertificateRequest(certReq, "sports.scores", null, null, "10.0.0.1")); ouValues.add("Testing Domain"); @@ -4301,7 +4301,7 @@ public void testValidateRoleCertificateRequestWithUriHostname() throws IOExcepti String pem = new String(Files.readAllBytes(path)); X509Certificate cert = Crypto.loadX509Certificate(pem); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); // if the CSR has hostname, but the cert doesn't have hostname, it should result in false assertFalse(zts.validateRoleCertificateRequest(certReq, "athenz.examples.httpd", @@ -4316,7 +4316,7 @@ public void testValidateRoleCertificateRequestWithUriHostname() throws IOExcepti path = Paths.get("src/test/resources/athenz.examples.role-uri-instanceid-hostname.csr"); csr = new String(Files.readAllBytes(path)); - certReq = new X509RoleCertRequest(csr); + certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); // if CSR has hostname+instanceid, and cert has only hostname, it should result in false assertFalse(zts.validateRoleCertificateRequest(certReq, "athenz.examples.httpd", @@ -4382,7 +4382,7 @@ public void testValidateRoleCertificateRequestOUWithCert() throws IOException { pem = new String(Files.readAllBytes(path)); X509Certificate invalidCert = Crypto.loadX509Certificate(pem); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); zts.validCertSubjectOrgValues = null; @@ -4408,7 +4408,7 @@ public void testValidateRoleCertificateRequestMismatchIP() throws IOException { String pem = new String(Files.readAllBytes(path)); X509Certificate cert = Crypto.loadX509Certificate(pem); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, zts.spiffeUriManager); // disable IP validation and we should get success @@ -4440,7 +4440,7 @@ public void testProcessRoleCertificateRequestFailedValidation() { RoleCertificateRequest req = new RoleCertificateRequest(); - X509RoleCertRequest certReq = new X509RoleCertRequest(ROLE_CERT_CORETECH_REQUEST); + X509RoleCertRequest certReq = new X509RoleCertRequest(ROLE_CERT_CORETECH_REQUEST, zts.spiffeUriManager); Set origUnitValues = zts.validCertSubjectOrgUnitValues; boolean verifyCertSubjectOU = zts.verifyCertSubjectOU; @@ -9065,7 +9065,7 @@ public void testValidateServiceX509RefreshRequest() throws IOException { Path path = Paths.get("src/test/resources/valid_provider_refresh.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, ztsImpl.spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/valid_provider_refresh.pem"); @@ -9092,7 +9092,7 @@ public void testValidateServiceX509RefreshRequestMismatchPublicKeys() throws IOE Path path = Paths.get("src/test/resources/valid_provider_refresh.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, ztsImpl.spiffeUriManager); assertNotNull(certReq); certReq.setNormCsrPublicKey("mismatch-public-key"); @@ -9120,7 +9120,7 @@ public void testValidateServiceX509RefreshRequestNotAllowedIP() throws IOExcepti Path path = Paths.get("src/test/resources/valid_provider_refresh.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, ztsImpl.spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/valid_provider_refresh.pem"); @@ -9149,7 +9149,7 @@ public void testValidateServiceX509RefreshRequestMismatchDns() throws IOExceptio Path path = Paths.get("src/test/resources/athenz.mismatch.dns.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, ztsImpl.spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/athenz.instanceid.pem"); @@ -14606,7 +14606,7 @@ public void testGenerateInstanceConfirmObjectWithCtxCert() throws IOException { path = Paths.get("src/test/resources/athenz.instanceid.csr"); String certCsr = new String(Files.readAllBytes(path)); - X509CertRequest certRequest = new X509ServiceCertRequest(certCsr); + X509CertRequest certRequest = new X509ServiceCertRequest(certCsr, zts.spiffeUriManager); InstanceConfirmation confirmation = ztsImpl.newInstanceConfirmationForRegister(context, "secureboot.provider", diff --git a/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509CertRequestTest.java b/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509CertRequestTest.java index 2357fa2a3eb..9d425abd4ee 100644 --- a/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509CertRequestTest.java +++ b/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509CertRequestTest.java @@ -25,6 +25,7 @@ import static org.testng.Assert.*; import com.yahoo.athenz.common.server.dns.HostnameResolver; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import com.yahoo.athenz.zts.CertType; import com.yahoo.athenz.zts.cache.DataCache; import com.yahoo.athenz.zts.cert.impl.TestHostnameResolver; @@ -37,13 +38,15 @@ public class X509CertRequestTest { + final SpiffeUriManager spiffeUriManager = new SpiffeUriManager(); + @Test public void testConstructorValidCsr() throws IOException { Path path = Paths.get("src/test/resources/valid_email.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); } @@ -52,7 +55,7 @@ public void testConstructorInvalidCsr() { X509CertRequest certReq = null; try { - certReq = new X509CertRequest("csr"); + certReq = new X509CertRequest("csr", spiffeUriManager); fail(); } catch (CryptoException ignored) { } @@ -63,14 +66,14 @@ public void testConstructorInvalidCsr() { public void testConstructorValidUriHostname() throws IOException { Path path = Paths.get("src/test/resources/athenz.examples.uri-instanceid-hostname.csr"); - X509CertRequest certReq = new X509CertRequest(new String(Files.readAllBytes(path))); + X509CertRequest certReq = new X509CertRequest(new String(Files.readAllBytes(path)), spiffeUriManager); assertNotNull(certReq); assertEquals(certReq.getUriHostname(), "abc.athenz.com"); path = Paths.get("src/test/resources/athenz.examples.uri-hostname-only.csr"); - certReq = new X509CertRequest(new String(Files.readAllBytes(path))); + certReq = new X509CertRequest(new String(Files.readAllBytes(path)), spiffeUriManager); assertNotNull(certReq); assertEquals(certReq.getUriHostname(), "abc.athenz.com"); } @@ -80,7 +83,7 @@ public void testParseCertRequestIPs() throws IOException { Path path = Paths.get("src/test/resources/multiple_ips.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); List values = certReq.getDnsNames(); @@ -99,7 +102,7 @@ public void testParseCertRequestInvalid() throws IOException { Path path = Paths.get("src/test/resources/invalid_dns.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); } @@ -108,7 +111,7 @@ public void testValidateCommonName() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertTrue(certReq.validateCommonName("athenz.production")); @@ -123,7 +126,7 @@ public void testValidateUriHostname() throws IOException { Path path = Paths.get("src/test/resources/athenz.examples.uri-instanceid-hostname.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertTrue(certReq.validateUriHostname("abc.athenz.com")); @@ -136,7 +139,7 @@ public void testValidateUriHostname() throws IOException { path = Paths.get("src/test/resources/athenz.examples.uri-hostname-empty.csr"); csr = new String(Files.readAllBytes(path)); - certReq = new X509CertRequest(csr); + certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertTrue(certReq.validateUriHostname("abc.athenz.com")); } @@ -146,7 +149,7 @@ public void testInstanceId() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertEquals(certReq.getInstanceId(), "1001"); @@ -158,7 +161,7 @@ public void testValidateDnsNamesWithCert() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/athenz.instanceid.pem"); @@ -174,7 +177,7 @@ public void testValidateDnsNamesWithValues() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); List providerDnsSuffixList = new ArrayList<>(); @@ -233,7 +236,7 @@ public void testValidateDnsNamesWithCnameValues() throws IOException { String csr = new String(Files.readAllBytes(path)); String service = "athenz.production"; - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); DataCache athenzSysDomainCache = Mockito.mock(DataCache.class); @@ -276,7 +279,7 @@ public void testValidateDnsNamesWithCnameValuesWithSameSuffix() throws IOExcepti String csr = new String(Files.readAllBytes(path)); String service = "athenz.production"; - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); DataCache athenzSysDomainCache = Mockito.mock(DataCache.class); @@ -310,7 +313,7 @@ public void testValidateDnsNamesWithMultipleDomainValues() throws IOException { Path path = Paths.get("src/test/resources/multi_dns_domain.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); List providerDnsSuffixList = new ArrayList<>(); @@ -383,7 +386,7 @@ public void testValidateUri() throws IOException { Path path = Paths.get("src/test/resources/multi_dns_domain.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); List providerDnsSuffixList = new ArrayList<>(); @@ -414,7 +417,7 @@ public void testValidateDnsNamesHostnameNullLists() throws IOException { Path path = Paths.get("src/test/resources/multi_dns_domain.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); List providerDnsSuffixList = new ArrayList<>(); @@ -444,7 +447,7 @@ public void testValidateDnsNamesHostnameNotAllowed() throws IOException { Path path = Paths.get("src/test/resources/multi_dns_domain.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); List providerDnsSuffixList = new ArrayList<>(); @@ -501,7 +504,7 @@ public void testValidateProviderDnsNamesList() throws IOException { Path path = Paths.get("src/test/resources/multi_dns_domain.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); // now add the hostname to the list @@ -534,7 +537,7 @@ public void testValidateProviderDnsNamesListWithWildcard() throws IOException { Path path = Paths.get("src/test/resources/multi_dns_domain_wildcard.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); // now add the hostname to the list @@ -565,7 +568,7 @@ public void testValidateProviderDnsNamesListWithWildcardMismatch() throws IOExce Path path = Paths.get("src/test/resources/multi_dns_domain_wildcard_mismatch.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); // now add the hostname to the list @@ -598,7 +601,7 @@ public void testValidateDnsNamesNoValues() throws IOException { Path path = Paths.get("src/test/resources/valid_cn_only.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); DataCache athenzSysDomainCache = Mockito.mock(DataCache.class); @@ -616,7 +619,7 @@ public void testValidateDnsNamesMismatchSize() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/valid_cn_x509.cert"); @@ -632,7 +635,7 @@ public void testValidateDnsNamesMismatchValues() throws IOException { Path path = Paths.get("src/test/resources/athenz.mismatch.dns.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/athenz.instanceid.pem"); @@ -648,7 +651,7 @@ public void testValidatePublicKeysCert() throws IOException { Path path = Paths.get("src/test/resources/valid_provider_refresh.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/valid_provider_refresh.pem"); @@ -664,7 +667,7 @@ public void testValidatePublicKeysCertFailure() throws IOException { Path path = Paths.get("src/test/resources/valid_provider_refresh.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); X509Certificate cert = Mockito.mock(X509Certificate.class); @@ -679,7 +682,7 @@ public void testValidatePublicKeysCertCSRFailure() throws IOException { Path path = Paths.get("src/test/resources/valid_provider_refresh.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); PKCS10CertificationRequest req = Mockito.mock(PKCS10CertificationRequest.class); @@ -699,7 +702,7 @@ public void testValidatePublicKeysCertMismatch() throws IOException { Path path = Paths.get("src/test/resources/athenz.mismatch.dns.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/athenz.instanceid.pem"); @@ -715,7 +718,7 @@ public void testValidatePublicKeysNull() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertFalse(certReq.validatePublicKeys((String) null)); @@ -727,7 +730,7 @@ public void testValidatePublicKeysFailure() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); PKCS10CertificationRequest req = Mockito.mock(PKCS10CertificationRequest.class); @@ -742,7 +745,7 @@ public void testValidatePublicKeysString() throws IOException { Path path = Paths.get("src/test/resources/valid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); final String ztsPublicKey = "-----BEGIN PUBLIC KEY-----\n" + "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAKrvfvBgXWqWAorw5hYJu3dpOJe0gp3n\n" @@ -757,7 +760,7 @@ public void testValidateCertReqPublicKey() throws IOException { Path path = Paths.get("src/test/resources/valid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); final String ztsPublicKey = "-----BEGIN PUBLIC KEY-----\n" @@ -773,7 +776,7 @@ public void testValidateCertReqPublicKeyMismatch() throws IOException { Path path = Paths.get("src/test/resources/valid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); final String ztsPublicKey = "-----BEGIN PUBLIC KEY-----\n" @@ -789,7 +792,7 @@ public void testValidateCertReqPublicKeyWhitespace() throws IOException { Path path = Paths.get("src/test/resources/valid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); final String ztsPublicKey1 = " -----BEGIN PUBLIC KEY-----\n" @@ -812,7 +815,7 @@ public void testValidateCertCNFailure() throws IOException { String csr = new String(Files.readAllBytes(path)); try { - new X509CertRequest(csr); + new X509CertRequest(csr, spiffeUriManager); fail(); } catch (CryptoException ex) { assertTrue(ex.getMessage().contains("Subject contains multiple values")); @@ -827,7 +830,7 @@ public void testValidateOUFieldCheck() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); testValidateOUFieldCheck(certReq); @@ -838,7 +841,7 @@ public void testValidateOUFieldCheck() throws IOException { path = Paths.get("src/test/resources/athenz.instanceid.restricted.csr"); csr = new String(Files.readAllBytes(path)); - certReq = new X509CertRequest(csr); + certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); testValidateOUFieldCheck(certReq); @@ -878,7 +881,7 @@ public void testValidateOUFieldCheckMissingOU() throws IOException { Path path = Paths.get("src/test/resources/athenz.single_ip.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); HashSet validOrgUnits = new HashSet<>(); @@ -899,7 +902,7 @@ public void testValidateOUFieldCheckInvalidOU() throws IOException { Path path = Paths.get("src/test/resources/athenz.multiple_ou.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertFalse(certReq.validateSubjectOUField("Athenz", null, null)); @@ -912,7 +915,7 @@ public void testExtractInstanceIdURI() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.uri.csr"); String csr = new String(Files.readAllBytes(path)); - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertEquals(certReq.getInstanceId(), "id-001"); @@ -925,7 +928,7 @@ public void testValidateInstanceCnames() throws IOException { String csr = new String(Files.readAllBytes(path)); String service = "athenz.api"; - X509CertRequest certReq = new X509CertRequest(csr); + X509CertRequest certReq = new X509CertRequest(csr, spiffeUriManager); assertNotNull(certReq); // cnames null and empty is always true diff --git a/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509RoleCertRequestTest.java b/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509RoleCertRequestTest.java index 5189e84bb89..6c0f676551d 100644 --- a/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509RoleCertRequestTest.java +++ b/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509RoleCertRequestTest.java @@ -16,6 +16,7 @@ package com.yahoo.athenz.zts.cert; import com.yahoo.athenz.auth.util.Crypto; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import org.testng.annotations.Test; import java.io.IOException; @@ -29,12 +30,14 @@ public class X509RoleCertRequestTest { + final SpiffeUriManager spiffeUriManager = new SpiffeUriManager(); + @Test public void testX509RoleCertRequest() throws IOException { Path path = Paths.get("src/test/resources/spiffe_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertEquals(certReq.getReqRoleDomain(), "coretech"); @@ -55,7 +58,7 @@ public void testValidateSpiffeRoleCert() throws IOException { Path path = Paths.get("src/test/resources/spiffe_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); Set orgValues = new HashSet<>(); orgValues.add("Athenz"); @@ -69,7 +72,7 @@ public void testValidateRoleIPAddressNoIPs() throws IOException { Path path = Paths.get("src/test/resources/spiffe_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateIPAddress(null, "10.10.11.12")); } @@ -79,7 +82,7 @@ public void testValidateRoleIPAddressNoCert() throws IOException { Path path = Paths.get("src/test/resources/role_single_ip.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateIPAddress(null, "10.11.12.13")); assertFalse(certReq.validateIPAddress(null, "10.10.11.12")); } @@ -94,7 +97,7 @@ public void testValidateRoleIPAddressCertNoIPs() throws IOException { String pem = new String(Files.readAllBytes(path)); X509Certificate cert = Crypto.loadX509Certificate(pem); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateIPAddress(cert, "10.11.12.13")); assertFalse(certReq.validateIPAddress(cert, "10.10.11.12")); } @@ -113,7 +116,7 @@ public void testValidateRoleIPAddressCertIPs() throws IOException { pem = new String(Files.readAllBytes(path)); X509Certificate cert2 = Crypto.loadX509Certificate(pem); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateIPAddress(cert1, "10.11.12.13")); assertTrue(certReq.validateIPAddress(cert2, "10.11.12.13")); } @@ -132,7 +135,7 @@ public void testValidateRoleIPAddressCertMultipleIPs() throws IOException { pem = new String(Files.readAllBytes(path)); X509Certificate cert2 = Crypto.loadX509Certificate(pem); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertFalse(certReq.validateIPAddress(cert1, "10.11.12.13")); assertTrue(certReq.validateIPAddress(cert2, "10.11.12.13")); } @@ -143,7 +146,7 @@ public void testValidateMissingProxyUserUri() throws IOException { Path path = Paths.get("src/test/resources/spiffe_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); Set orgValues = new HashSet<>(); orgValues.add("Athenz"); @@ -157,7 +160,7 @@ public void testValidateNoProxyUserUri() throws IOException { Path path = Paths.get("src/test/resources/role_single_ip.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); Set orgValues = new HashSet<>(); orgValues.add("Athenz"); @@ -171,7 +174,7 @@ public void testValidateMultipleProxyUserUri() throws IOException { Path path = Paths.get("src/test/resources/multiple_proxy_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); Set orgValues = new HashSet<>(); orgValues.add("Athenz"); @@ -185,7 +188,7 @@ public void testValidateProxyUserUri() throws IOException { Path path = Paths.get("src/test/resources/proxy_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); Set orgValues = new HashSet<>(); orgValues.add("Athenz"); @@ -203,7 +206,7 @@ public void testRoleCertValidatePrincipalURINoEmail() throws IOException { Path path = Paths.get("src/test/resources/athenz_role_principal_uri.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validate("athenz.production", null, null)); assertFalse(certReq.validate("athenz.api", null, null)); } @@ -214,7 +217,7 @@ public void testRoleCertValidatePrincipalURIWithEmail() throws IOException { Path path = Paths.get("src/test/resources/athenz_role_principal_uri_email.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validate("athenz.production", null, null)); assertFalse(certReq.validate("athenz.api", null, null)); } @@ -225,7 +228,7 @@ public void testRoleCertValidatePrincipalURIWithEmailMismatch() throws IOExcepti Path path = Paths.get("src/test/resources/athenz_role_principal_uri_email_mismatch.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertFalse(certReq.validate("athenz.production", null, null)); assertFalse(certReq.validate("athenz.api", null, null)); } @@ -236,7 +239,7 @@ public void testValidateSpiffeURIWithoutTrustDomain() throws IOException { Path path = Paths.get("src/test/resources/spiffe_role.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateSpiffeURI("coretech", "api")); assertFalse(certReq.validateSpiffeURI("coretech", "backend")); } @@ -247,7 +250,7 @@ public void testValidateSpiffeURIWithTrustDomain() throws IOException { Path path = Paths.get("src/test/resources/spiffe_role_trust_domain.csr"); String csr = new String(Files.readAllBytes(path)); - X509RoleCertRequest certReq = new X509RoleCertRequest(csr); + X509RoleCertRequest certReq = new X509RoleCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateSpiffeURI("coretech", "api")); assertFalse(certReq.validateSpiffeURI("coretech", "backend")); } diff --git a/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequestTest.java b/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequestTest.java index 508f2d0c58d..149ab9461e9 100644 --- a/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequestTest.java +++ b/servers/zts/src/test/java/com/yahoo/athenz/zts/cert/X509ServiceCertRequestTest.java @@ -17,6 +17,7 @@ import com.yahoo.athenz.auth.util.Crypto; import com.yahoo.athenz.auth.util.CryptoException; +import com.yahoo.athenz.common.server.spiffe.SpiffeUriManager; import com.yahoo.athenz.zts.cache.DataCache; import org.mockito.Mockito; import org.testng.annotations.DataProvider; @@ -33,13 +34,15 @@ public class X509ServiceCertRequestTest { + final SpiffeUriManager spiffeUriManager = new SpiffeUriManager(); + @Test public void testValidateInvalidDnsNames() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -53,7 +56,7 @@ public void testValidateInvalidInstanceId() throws IOException { Path path = Paths.get("src/test/resources/valid_email.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -72,7 +75,7 @@ public void testValidateInstanceIdMismatch() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); path = Paths.get("src/test/resources/athenz.instanceid.pem"); @@ -89,7 +92,7 @@ public void testValidateCnMismatch() throws IOException { Path path = Paths.get("src/test/resources/athenz.mismatch.cn.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -109,7 +112,7 @@ public void testValidateDnsSuffixMismatch() throws IOException { Path path = Paths.get("src/test/resources/athenz.mismatch.dns.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -129,7 +132,7 @@ public void testValidateOFieldCheck() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -156,7 +159,7 @@ public void testValidateOFieldCheckNoValue() throws IOException { Path path = Paths.get("src/test/resources/valid_cn_only.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); HashSet validOrgs = new HashSet<>(); @@ -171,7 +174,7 @@ public void testValidateOFieldCheckMultipleValue() throws IOException { Path path = Paths.get("src/test/resources/multiple_org.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); HashSet validOrgs = new HashSet<>(); @@ -186,7 +189,7 @@ public void testValidate() throws IOException { Path path = Paths.get("src/test/resources/athenz.instanceid.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -225,7 +228,7 @@ public void testValidateSpiffeUri(final String csrPath, boolean expectedResult) Path path = Paths.get(csrPath); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -247,7 +250,7 @@ public void testValidateIPAddressMultipleIPs() throws IOException { Path path = Paths.get("src/test/resources/multiple_ips.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertFalse(certReq.validateIPAddress("10.11.12.14")); @@ -259,7 +262,7 @@ public void testValidateIPAddressNoIPs() throws IOException { Path path = Paths.get("src/test/resources/valid.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertTrue(certReq.validateIPAddress("10.11.12.14")); @@ -271,7 +274,7 @@ public void testValidateIPAddressMismatchIPs() throws IOException { Path path = Paths.get("src/test/resources/athenz.single_ip.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertFalse(certReq.validateIPAddress("10.11.12.14")); @@ -283,7 +286,7 @@ public void testValidateIPAddress() throws IOException { Path path = Paths.get("src/test/resources/athenz.single_ip.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertTrue(certReq.validateIPAddress("10.11.12.13")); @@ -295,7 +298,7 @@ public void testValidateUriHostname() throws IOException { Path path = Paths.get("src/test/resources/athenz.examples.uri-instanceid-hostname.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); assertTrue(certReq.validateUriHostname("abc.athenz.com")); @@ -313,7 +316,7 @@ public void testValidateWithUriHostname() throws IOException { Path path = Paths.get("src/test/resources/athenz.examples.uri-instanceid-hostname.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertNotNull(certReq); StringBuilder errorMsg = new StringBuilder(256); @@ -344,7 +347,7 @@ public void testValidateSpiffeURIWithoutURI() throws IOException { Path path = Paths.get("src/test/resources/valid.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateSpiffeURI("domain", "api", null)); assertTrue(certReq.validateSpiffeURI("domain", "api", "default")); } @@ -355,7 +358,7 @@ public void testValidateSpiffeURIWithNamespace() throws IOException { Path path = Paths.get("src/test/resources/spiffe-namespace.csr"); String csr = new String(Files.readAllBytes(path)); - X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr); + X509ServiceCertRequest certReq = new X509ServiceCertRequest(csr, spiffeUriManager); assertTrue(certReq.validateSpiffeURI("athenz", "production", "default")); assertFalse(certReq.validateSpiffeURI("athenz", "production", "test")); @@ -372,7 +375,7 @@ public void testValidateSpiffeURIMultipleValues() throws IOException { String csr = new String(Files.readAllBytes(path)); try { - new X509ServiceCertRequest(csr); + new X509ServiceCertRequest(csr, spiffeUriManager); fail(); } catch (CryptoException ex) { assertTrue(ex.getMessage().contains("Invalid SPIFFE URI present"));