Skip to content

Commit

Permalink
Get the key ID only during tainting
Browse files Browse the repository at this point in the history
Signed-off-by: Agustín Martínez Fayó <[email protected]>
  • Loading branch information
amartinezfayo committed Oct 10, 2024
1 parent 3e0bef3 commit 444ede9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 50 deletions.
27 changes: 0 additions & 27 deletions pkg/agent/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"sync"
"time"

"github.com/go-jose/go-jose/v4/jwt"
"github.com/sirupsen/logrus"
"github.com/spiffe/go-spiffe/v2/spiffeid"
agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1"
Expand All @@ -20,7 +19,6 @@ import (
svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1"
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
"github.com/spiffe/spire/pkg/common/bundleutil"
"github.com/spiffe/spire/pkg/common/jwtsvid"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/proto/spire/common"
"google.golang.org/grpc"
Expand Down Expand Up @@ -54,7 +52,6 @@ type JWTSVID struct {
Token string
IssuedAt time.Time
ExpiresAt time.Time
Kid string
}

type SyncStats struct {
Expand Down Expand Up @@ -335,16 +332,10 @@ func (c *client) NewJWTSVID(ctx context.Context, entryID string, audience []stri
return nil, errors.New("JWTSVID issued after it has expired")
}

keyID, err := getKeyIDFromSVIDToken(svid.Token)
if err != nil {
return nil, err
}

return &JWTSVID{
Token: svid.Token,
IssuedAt: time.Unix(svid.IssuedAt, 0).UTC(),
ExpiresAt: time.Unix(svid.ExpiresAt, 0).UTC(),
Kid: keyID,
}, nil
}

Expand Down Expand Up @@ -735,21 +726,3 @@ func (c *client) withErrorFields(err error) logrus.FieldLogger {

return logger
}

func getKeyIDFromSVIDToken(svidToken string) (string, error) {
token, err := jwt.ParseSigned(svidToken, jwtsvid.AllowedSignatureAlgorithms)
if err != nil {
return "", fmt.Errorf("failed to parse JWT-SVID: %w", err)
}

if len(token.Headers) != 1 {
return "", fmt.Errorf("malformed JWT-SVID: expected a single token header; got %d", len(token.Headers))
}

keyID := token.Headers[0].KeyID
if keyID == "" {
return "", errors.New("missing key ID in token header of minted JWT-SVID")
}

return keyID, nil
}
14 changes: 6 additions & 8 deletions pkg/agent/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,6 @@ func TestFetchJWTSVID(t *testing.T) {

issuedAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Minute).Unix()
tok := "eyJhbGciOiJFUzI1NiIsImtpZCI6InlueVV3TEg5bklOaWcyWEdRdTBkdklWbndieG5JeHlPIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjYwOTIsImlhdCI6MTcyNDI3OTc0Nywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.feVAZPLmWT4ohzNKBoil90WBn64noqnCYuXXJjhItdtsOmJqMBm-blfJl4pBvbRBCWym2YaXK9gl9RoAfFG0zQ"
for _, tt := range []struct {
name string
setupTest func(err error)
Expand All @@ -894,17 +893,16 @@ func TestFetchJWTSVID(t *testing.T) {
name: "success",
setupTest: func(err error) {
tc.svidServer.jwtSVID = &types.JWTSVID{
Token: tok,
Token: "token",
ExpiresAt: expiresAt,
IssuedAt: issuedAt,
}
tc.svidServer.newJWTSVID = err
},
expectSVID: &JWTSVID{
Token: tok,
Token: "token",
ExpiresAt: time.Unix(expiresAt, 0).UTC(),
IssuedAt: time.Unix(issuedAt, 0).UTC(),
Kid: "ynyUwLH9nINig2XGQu0dvIVnwbxnIxyO",
},
},
{
Expand All @@ -927,7 +925,7 @@ func TestFetchJWTSVID(t *testing.T) {
name: "missing issuedAt",
setupTest: func(err error) {
tc.svidServer.jwtSVID = &types.JWTSVID{
Token: tok,
Token: "token",
ExpiresAt: expiresAt,
}
tc.svidServer.newJWTSVID = err
Expand All @@ -938,7 +936,7 @@ func TestFetchJWTSVID(t *testing.T) {
name: "missing expiredAt",
setupTest: func(err error) {
tc.svidServer.jwtSVID = &types.JWTSVID{
Token: tok,
Token: "token",
IssuedAt: issuedAt,
}
tc.svidServer.newJWTSVID = err
Expand All @@ -949,7 +947,7 @@ func TestFetchJWTSVID(t *testing.T) {
name: "issued after expired",
setupTest: func(err error) {
tc.svidServer.jwtSVID = &types.JWTSVID{
Token: tok,
Token: "token",
ExpiresAt: issuedAt,
IssuedAt: expiresAt,
}
Expand All @@ -961,7 +959,7 @@ func TestFetchJWTSVID(t *testing.T) {
name: "grpc call to NewJWTSVID fails",
setupTest: func(err error) {
tc.svidServer.jwtSVID = &types.JWTSVID{
Token: tok,
Token: "token",
ExpiresAt: expiresAt,
IssuedAt: issuedAt,
}
Expand Down
31 changes: 29 additions & 2 deletions pkg/agent/manager/cache/jwt_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"sort"
"strings"
"sync"

"github.com/go-jose/go-jose/v4/jwt"
"github.com/sirupsen/logrus"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/agent/client"
"github.com/spiffe/spire/pkg/common/jwtsvid"
"github.com/spiffe/spire/pkg/common/telemetry"
"github.com/spiffe/spire/pkg/common/telemetry/agent"
)
Expand Down Expand Up @@ -62,9 +66,14 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities
var taintedKeyIDs []string
svidsRemoved := 0
for key, jwtSVID := range c.svids {
if _, tainted := taintedJWTAuthorities[jwtSVID.Kid]; tainted {
keyID, err := getKeyIDFromSVIDToken(jwtSVID.Token)
if err != nil {
c.log.Error(err)
continue
}
if _, tainted := taintedJWTAuthorities[keyID]; tainted {
delete(c.svids, key)
taintedKeyIDs = append(taintedKeyIDs, jwtSVID.Kid)
taintedKeyIDs = append(taintedKeyIDs, keyID)
svidsRemoved++
}
select {
Expand All @@ -83,6 +92,24 @@ func (c *JWTSVIDCache) TaintJWTSVIDs(ctx context.Context, taintedJWTAuthorities
agent.AddCacheManagerTaintedJWTSVIDsSample(c.metrics, agent.CacheTypeWorkload, float32(taintedKeyIDsCount))
}

func getKeyIDFromSVIDToken(svidToken string) (string, error) {
token, err := jwt.ParseSigned(svidToken, jwtsvid.AllowedSignatureAlgorithms)
if err != nil {
return "", fmt.Errorf("failed to parse JWT-SVID: %w", err)
}

if len(token.Headers) != 1 {
return "", fmt.Errorf("malformed JWT-SVID: expected a single token header; got %d", len(token.Headers))
}

keyID := token.Headers[0].KeyID
if keyID == "" {
return "", errors.New("missing key ID in token header of minted JWT-SVID")
}

return keyID, nil
}

func jwtSVIDKey(spiffeID spiffeid.ID, audience []string) string {
h := sha256.New()

Expand Down
8 changes: 5 additions & 3 deletions pkg/agent/manager/cache/jwt_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ import (

func TestJWTSVIDCacheBasic(t *testing.T) {
now := time.Now()
expected := &client.JWTSVID{Token: "X", IssuedAt: now, ExpiresAt: now.Add(time.Second), Kid: "the-kid"}
tok := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg"
keyID := "dZDfYiw1uGzMwdMYHL7FEYyK8HOKKwLX"
expected := &client.JWTSVID{Token: tok, IssuedAt: now, ExpiresAt: now.Add(time.Second)}

fakeMetrics := fakemetrics.New()
log, hook := test.NewNullLogger()
Expand All @@ -40,7 +42,7 @@ func TestJWTSVIDCacheBasic(t *testing.T) {
assert.Equal(t, expected, actual)

// Remove tainted authority, should not be cached anymore
cache.TaintJWTSVIDs(context.Background(), map[string]struct{}{"the-kid": {}})
cache.TaintJWTSVIDs(context.Background(), map[string]struct{}{keyID: {}})
actual, ok = cache.GetJWTSVID(spiffeID, []string{"bar"})
assert.False(t, ok)
assert.Nil(t, actual)
Expand All @@ -52,7 +54,7 @@ func TestJWTSVIDCacheBasic(t *testing.T) {
Message: "JWT-SVIDs were removed from the JWT cache because they were issued by a tainted authority",
Data: logrus.Fields{
telemetry.CountJWTSVIDs: "1",
telemetry.JWTAuthorityKeyIDs: "the-kid",
telemetry.JWTAuthorityKeyIDs: keyID,
},
},
}
Expand Down
19 changes: 9 additions & 10 deletions pkg/agent/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1531,47 +1531,46 @@ func TestFetchJWTSVID(t *testing.T) {

now := clk.Now()
// fetch succeeds
tokA := "eyJhbGciOiJFUzI1NiIsImtpZCI6ImRaRGZZaXcxdUd6TXdkTVlITDdGRVl5SzhIT0tLd0xYIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjU3MzEsImlhdCI6MTcyNDI3OTQwNywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.dFr-oWhm5tK0bBuVXt-sGESM5l7hhoY-Gtt5DkuFoJL5Y9d4ZfmicCvUCjL4CqDB3BO_cPqmFfrO7H7pxQbGLg"
tokenA := "A"
issuedAtA := now.Unix()
expiresAtA := now.Add(time.Minute).Unix()
fetchResp.Svid = &types.JWTSVID{
Token: tokA,
Token: tokenA,
IssuedAt: issuedAtA,
ExpiresAt: expiresAtA,
}
svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience)
require.NoError(t, err)
require.Equal(t, tokA, svid.Token)
require.Equal(t, tokenA, svid.Token)
require.Equal(t, issuedAtA, svid.IssuedAt.Unix())
require.Equal(t, expiresAtA, svid.ExpiresAt.Unix())

// assert cached JWT is returned w/o trying to fetch (since cached version does not expire soon)
tokB := "eyJhbGciOiJFUzI1NiIsImtpZCI6InQ4ajd4cmdtVDVGRzNkbG45ZnFYanlNQzF0emxyT1B6IiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjYwNDksImlhdCI6MTcyNDI3OTY1OSwic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.68FB0ubhQsirJE_K0QSQxn_b06OZaWNrgC3nJCDupHEbSftXuGiAFDqUqK_HGMoKC1Nz9bxBNNQSoL50H4C3vw"
fetchResp.Svid = &types.JWTSVID{
Token: tokB,
Token: "B",
IssuedAt: now.Unix(),
ExpiresAt: now.Add(time.Minute).Unix(),
}
svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience)
require.NoError(t, err)
require.Equal(t, tokA, svid.Token)
require.Equal(t, tokenA, svid.Token)
require.Equal(t, issuedAtA, svid.IssuedAt.Unix())
require.Equal(t, expiresAtA, svid.ExpiresAt.Unix())

// expire the cached JWT soon and make sure new JWT is fetched
clk.Add(time.Second * 45)
now = clk.Now()
tokC := "eyJhbGciOiJFUzI1NiIsImtpZCI6InlueVV3TEg5bklOaWcyWEdRdTBkdklWbndieG5JeHlPIiwidHlwIjoiSldUIn0.eyJhdWQiOlsidGVzdC1hdWRpZW5jZSJdLCJleHAiOjE3MjQzNjYwOTIsImlhdCI6MTcyNDI3OTc0Nywic3ViIjoic3BpZmZlOi8vZXhhbXBsZS5vcmcvYWdlbnQvZGJ1c2VyIn0.feVAZPLmWT4ohzNKBoil90WBn64noqnCYuXXJjhItdtsOmJqMBm-blfJl4pBvbRBCWym2YaXK9gl9RoAfFG0zQ"
tokenC := "C"
issuedAtC := now.Unix()
expiresAtC := now.Add(time.Minute).Unix()
fetchResp.Svid = &types.JWTSVID{
Token: tokC,
Token: tokenC,
IssuedAt: issuedAtC,
ExpiresAt: expiresAtC,
}
svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience)
require.NoError(t, err)
require.Equal(t, tokC, svid.Token)
require.Equal(t, tokenC, svid.Token)
require.Equal(t, issuedAtC, svid.IssuedAt.Unix())
require.Equal(t, expiresAtC, svid.ExpiresAt.Unix())

Expand All @@ -1580,7 +1579,7 @@ func TestFetchJWTSVID(t *testing.T) {
fetchResp.Svid = nil
svid, err = m.FetchJWTSVID(context.Background(), regEntriesMap["resp2"][0], audience)
require.NoError(t, err)
require.Equal(t, tokC, svid.Token)
require.Equal(t, tokenC, svid.Token)
require.Equal(t, issuedAtC, svid.IssuedAt.Unix())
require.Equal(t, expiresAtC, svid.ExpiresAt.Unix())

Expand Down

0 comments on commit 444ede9

Please sign in to comment.