From 444ede9c22254c92aa5d4a7ee3ff2af2f5b87e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Mart=C3=ADnez=20Fay=C3=B3?= Date: Thu, 10 Oct 2024 12:04:03 -0300 Subject: [PATCH] Get the key ID only during tainting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Martínez Fayó --- pkg/agent/client/client.go | 27 -------------------- pkg/agent/client/client_test.go | 14 +++++----- pkg/agent/manager/cache/jwt_cache.go | 31 +++++++++++++++++++++-- pkg/agent/manager/cache/jwt_cache_test.go | 8 +++--- pkg/agent/manager/manager_test.go | 19 +++++++------- 5 files changed, 49 insertions(+), 50 deletions(-) diff --git a/pkg/agent/client/client.go b/pkg/agent/client/client.go index 79ce88efd0..1ec3522f58 100644 --- a/pkg/agent/client/client.go +++ b/pkg/agent/client/client.go @@ -11,7 +11,6 @@ import ( "sync" "time" - "github.com/go-jose/go-jose/v4/jwt" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1" @@ -20,7 +19,6 @@ import ( svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/bundleutil" - "github.com/spiffe/spire/pkg/common/jwtsvid" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/proto/spire/common" "google.golang.org/grpc" @@ -54,7 +52,6 @@ type JWTSVID struct { Token string IssuedAt time.Time ExpiresAt time.Time - Kid string } type SyncStats struct { @@ -335,16 +332,10 @@ func (c *client) NewJWTSVID(ctx context.Context, entryID string, audience []stri return nil, errors.New("JWTSVID issued after it has expired") } - keyID, err := getKeyIDFromSVIDToken(svid.Token) - if err != nil { - return nil, err - } - return &JWTSVID{ Token: svid.Token, IssuedAt: time.Unix(svid.IssuedAt, 0).UTC(), ExpiresAt: time.Unix(svid.ExpiresAt, 0).UTC(), - Kid: keyID, }, nil } @@ -735,21 +726,3 @@ func (c *client) withErrorFields(err error) logrus.FieldLogger { return logger } - -func getKeyIDFromSVIDToken(svidToken string) (string, error) { - token, err := jwt.ParseSigned(svidToken, jwtsvid.AllowedSignatureAlgorithms) - if err != nil { - return "", fmt.Errorf("failed to parse JWT-SVID: %w", err) - } - - if len(token.Headers) != 1 { - return "", fmt.Errorf("malformed JWT-SVID: expected a single token header; got %d", len(token.Headers)) - } - - keyID := token.Headers[0].KeyID - if keyID == "" { - return "", errors.New("missing key ID in token header of minted JWT-SVID") - } - - return keyID, nil -} diff --git a/pkg/agent/client/client_test.go b/pkg/agent/client/client_test.go index eb7b9c7097..32ddae38e0 100644 --- a/pkg/agent/client/client_test.go +++ b/pkg/agent/client/client_test.go @@ -882,7 +882,6 @@ func TestFetchJWTSVID(t *testing.T) { issuedAt := time.Now().Unix() expiresAt := time.Now().Add(time.Minute).Unix() - tok := "eyJhbGciOiJFUzI1NiIsImtpZCI6InlueVV3TEg5bklOaWcyWEdRdTBkdklWbndieG5JeHlPIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjYwOTIsImlhdCI6MTcyNDI3OTc0Nywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.feVAZPLmWT4ohzNKBoil90WBn64noqnCYuXXJjhItdtsOmJqMBm-blfJl4pBvbRBCWym2YaXK9gl9RoAfFG0zQ" for _, tt := range []struct { name string setupTest func(err error) @@ -894,17 +893,16 @@ func TestFetchJWTSVID(t *testing.T) { name: "success", setupTest: func(err error) { tc.svidServer.jwtSVID = &types.JWTSVID{ - Token: tok, + Token: "token", ExpiresAt: expiresAt, IssuedAt: issuedAt, } tc.svidServer.newJWTSVID = err }, expectSVID: &JWTSVID{ - Token: tok, + Token: "token", ExpiresAt: time.Unix(expiresAt, 0).UTC(), IssuedAt: time.Unix(issuedAt, 0).UTC(), - Kid: "ynyUwLH9nINig2XGQu0dvIVnwbxnIxyO", }, }, { @@ -927,7 +925,7 @@ func TestFetchJWTSVID(t *testing.T) { name: "missing issuedAt", setupTest: func(err error) { tc.svidServer.jwtSVID = &types.JWTSVID{ - Token: tok, + Token: "token", ExpiresAt: expiresAt, } tc.svidServer.newJWTSVID = err @@ -938,7 +936,7 @@ func TestFetchJWTSVID(t *testing.T) { name: "missing expiredAt", setupTest: func(err error) { tc.svidServer.jwtSVID = &types.JWTSVID{ - Token: tok, + Token: "token", IssuedAt: issuedAt, } tc.svidServer.newJWTSVID = err @@ -949,7 +947,7 @@ func TestFetchJWTSVID(t *testing.T) { name: "issued after expired", setupTest: func(err error) { tc.svidServer.jwtSVID = &types.JWTSVID{ - Token: tok, + Token: "token", ExpiresAt: issuedAt, IssuedAt: expiresAt, } @@ -961,7 +959,7 @@ func TestFetchJWTSVID(t *testing.T) { name: "grpc call to NewJWTSVID fails", setupTest: func(err error) { tc.svidServer.jwtSVID = &types.JWTSVID{ - Token: tok, + Token: "token", ExpiresAt: expiresAt, IssuedAt: issuedAt, } diff --git a/pkg/agent/manager/cache/jwt_cache.go b/pkg/agent/manager/cache/jwt_cache.go index d289817ddf..b9f6a98393 100644 --- a/pkg/agent/manager/cache/jwt_cache.go +++ b/pkg/agent/manager/cache/jwt_cache.go @@ -4,14 +4,18 @@ import ( "context" "crypto/sha256" "encoding/base64" + "errors" + "fmt" "io" "sort" "strings" "sync" + "github.com/go-jose/go-jose/v4/jwt" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/client" + "github.com/spiffe/spire/pkg/common/jwtsvid" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/telemetry/agent" ) @@ -62,9 +66,14 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities var taintedKeyIDs []string svidsRemoved := 0 for key, jwtSVID := range c.svids { - if _, tainted := taintedJWTAuthorities[jwtSVID.Kid]; tainted { + keyID, err := getKeyIDFromSVIDToken(jwtSVID.Token) + if err != nil { + c.log.Error(err) + continue + } + if _, tainted := taintedJWTAuthorities[keyID]; tainted { delete(c.svids, key) - taintedKeyIDs = append(taintedKeyIDs, jwtSVID.Kid) + taintedKeyIDs = append(taintedKeyIDs, keyID) svidsRemoved++ } select { @@ -83,6 +92,24 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities agent.AddCacheManagerTaintedJWTSVIDsSample(c.metrics, agent.CacheTypeWorkload, float32(taintedKeyIDsCount)) } +func getKeyIDFromSVIDToken(svidToken string) (string, error) { + token, err := jwt.ParseSigned(svidToken, jwtsvid.AllowedSignatureAlgorithms) + if err != nil { + return "", fmt.Errorf("failed to parse JWT-SVID: %w", err) + } + + if len(token.Headers) != 1 { + return "", fmt.Errorf("malformed JWT-SVID: expected a single token header; got %d", len(token.Headers)) + } + + keyID := token.Headers[0].KeyID + if keyID == "" { + return "", errors.New("missing key ID in token header of minted JWT-SVID") + } + + return keyID, nil +} + func jwtSVIDKey(spiffeID spiffeid.ID, audience []string) string { h := sha256.New() diff --git a/pkg/agent/manager/cache/jwt_cache_test.go b/pkg/agent/manager/cache/jwt_cache_test.go index b9331d4deb..cc3402c4ec 100644 --- a/pkg/agent/manager/cache/jwt_cache_test.go +++ b/pkg/agent/manager/cache/jwt_cache_test.go @@ -19,7 +19,9 @@ import ( func TestJWTSVIDCacheBasic(t *testing.T) { now := time.Now() - expected := &client.JWTSVID{Token: "X", IssuedAt: now, ExpiresAt: now.Add(time.Second), Kid: "the-kid"} + tok := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg" + keyID := "dZDfYiw1uGzMwdMYHL7FEYyK8HOKKwLX" + expected := &client.JWTSVID{Token: tok, IssuedAt: now, ExpiresAt: now.Add(time.Second)} fakeMetrics := fakemetrics.New() log, hook := test.NewNullLogger() @@ -40,7 +42,7 @@ func TestJWTSVIDCacheBasic(t *testing.T) { assert.Equal(t, expected, actual) // Remove tainted authority, should not be cached anymore - cache.TaintJWTSVIDs(context.Background(), map[string]struct{}{"the-kid": {}}) + cache.TaintJWTSVIDs(context.Background(), map[string]struct{}{keyID: {}}) actual, ok = cache.GetJWTSVID(spiffeID, []string{"bar"}) assert.False(t, ok) assert.Nil(t, actual) @@ -52,7 +54,7 @@ func TestJWTSVIDCacheBasic(t *testing.T) { Message: "JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority", Data: logrus.Fields{ telemetry.CountJWTSVIDs: "1", - telemetry.JWTAuthorityKeyIDs: "the-kid", + telemetry.JWTAuthorityKeyIDs: keyID, }, }, } diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index a4f69f255e..695825e00c 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -1531,47 +1531,46 @@ func TestFetchJWTSVID(t *testing.T) { now := clk.Now() // fetch succeeds - tokA := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg" + tokenA := "A" issuedAtA := now.Unix() expiresAtA := now.Add(time.Minute).Unix() fetchResp.Svid = &types.JWTSVID{ - Token: tokA, + Token: tokenA, IssuedAt: issuedAtA, ExpiresAt: expiresAtA, } svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience) require.NoError(t, err) - require.Equal(t, tokA, svid.Token) + require.Equal(t, tokenA, svid.Token) require.Equal(t, issuedAtA, svid.IssuedAt.Unix()) require.Equal(t, expiresAtA, svid.ExpiresAt.Unix()) // assert cached JWT is returned w/o trying to fetch (since cached version does not expire soon) - tokB := "eyJhbGciOiJFUzI1NiIsImtpZCI6InQ4ajd4cmdtVDVGRzNkbG45ZnFYanlNQzF0emxyT1B6IiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjYwNDksImlhdCI6MTcyNDI3OTY1OSwic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.68FB0ubhQsirJE_K0QSQxn_b06OZaWNrgC3nJCDupHEbSftXuGiAFDqUqK_HGMoKC1Nz9bxBNNQSoL50H4C3vw" fetchResp.Svid = &types.JWTSVID{ - Token: tokB, + Token: "B", IssuedAt: now.Unix(), ExpiresAt: now.Add(time.Minute).Unix(), } svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience) require.NoError(t, err) - require.Equal(t, tokA, svid.Token) + require.Equal(t, tokenA, svid.Token) require.Equal(t, issuedAtA, svid.IssuedAt.Unix()) require.Equal(t, expiresAtA, svid.ExpiresAt.Unix()) // expire the cached JWT soon and make sure new JWT is fetched clk.Add(time.Second * 45) now = clk.Now() - tokC := "eyJhbGciOiJFUzI1NiIsImtpZCI6InlueVV3TEg5bklOaWcyWEdRdTBkdklWbndieG5JeHlPIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjYwOTIsImlhdCI6MTcyNDI3OTc0Nywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.feVAZPLmWT4ohzNKBoil90WBn64noqnCYuXXJjhItdtsOmJqMBm-blfJl4pBvbRBCWym2YaXK9gl9RoAfFG0zQ" + tokenC := "C" issuedAtC := now.Unix() expiresAtC := now.Add(time.Minute).Unix() fetchResp.Svid = &types.JWTSVID{ - Token: tokC, + Token: tokenC, IssuedAt: issuedAtC, ExpiresAt: expiresAtC, } svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience) require.NoError(t, err) - require.Equal(t, tokC, svid.Token) + require.Equal(t, tokenC, svid.Token) require.Equal(t, issuedAtC, svid.IssuedAt.Unix()) require.Equal(t, expiresAtC, svid.ExpiresAt.Unix()) @@ -1580,7 +1579,7 @@ func TestFetchJWTSVID(t *testing.T) { fetchResp.Svid = nil svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience) require.NoError(t, err) - require.Equal(t, tokC, svid.Token) + require.Equal(t, tokenC, svid.Token) require.Equal(t, issuedAtC, svid.IssuedAt.Unix()) require.Equal(t, expiresAtC, svid.ExpiresAt.Unix())