From 97c66a29912c9350a0f57859940bf299b2429ee7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Tue, 10 Sep 2024 02:11:41 +0000
Subject: [PATCH] support grpc oauth (#1980)

Enables using OAuth2 for calls to the Flow API

![Screenshot 2024-09-09 at 17 36 13](https://github.com/user-attachments/assets/26f3d8ac-7314-421e-892a-1c5b0d255b9a)

Requires the following env vars:
- `PEERDB_OAUTH_ISSUER_URL` - This is the OAuth Issuer URL of the JWT (Like `https://{AUTH0_DOMAIN}/`)
- `PEERDB_OAUTH_DISCOVERY_ENABLED` - Set this to `true` to enable discovery via `/.well-known/jwks.json` endpoint defined in openID spec
- `PEERDB_OAUTH_KEYSET_JSON` - If custom json keyset is to be provided.
- `PEERDB_OAUTH_JWT_CLAIM_KEY`, `PEERDB_OAUTH_JWT_CLAIM_VALUE` - any custom key-value to be additionally checked while validating the incoming jwt

Health Endpoints are explicitly excluded from auth.
---
 flow/auth/middleware.go | 303 ++++++++++++++++++++++++++++++++++++++++
 flow/cmd/api.go         |  10 +-
 flow/cmd/handler.go     |   5 +-
 flow/go.mod             |   7 ++
 flow/go.sum             |  14 +++
 flow/peerdbenv/oauth.go |  35 ++++++
 6 files changed, 369 insertions(+), 5 deletions(-)
 create mode 100644 flow/auth/middleware.go
 create mode 100644 flow/peerdbenv/oauth.go

diff --git a/flow/auth/middleware.go b/flow/auth/middleware.go
new file mode 100644
index 0000000000..e3fdfabb7f
--- /dev/null
+++ b/flow/auth/middleware.go
@@ -0,0 +1,303 @@
+package auth
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+	"net/url"
+	"strings"
+
+	"github.com/lestrrat-go/jwx/v2/jwk"
+	"github.com/lestrrat-go/jwx/v2/jws"
+	"github.com/lestrrat-go/jwx/v2/jwt"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+
+	"github.com/PeerDB-io/peer-flow/peerdbenv"
+)
+
+// AuthenticationConfig is the resolved OAuth configuration that the middleware is built from.
+//
+//nolint:lll
+type AuthenticationConfig struct {
+	OauthJwtCustomClaims  map[string]string `json:"oauth_custom_claims" yaml:"oauth_custom_claims" mapstructure:"oauth_custom_claims"`
+	KeySetJSON            string            `json:"key_set_json" yaml:"key_set_json" mapstructure:"key_set_json"`
+	OAuthIssuerUrl        string            `json:"oauth_domain" yaml:"oauth_domain" mapstructure:"oauth_domain"`
+	Enabled               bool              `json:"enabled" yaml:"enabled" mapstructure:"enabled"`
+	OAuthDiscoveryEnabled bool              `json:"oauth_discovery_enabled" yaml:"oauth_discovery_enabled" mapstructure:"oauth_discovery_enabled"`
+}
+
+// identityProvider pairs an issuer with the key set used to verify its tokens.
+type identityProvider struct {
+	keySet      jwk.Set
+	validateOpt jwt.ValidateOption
+	issuer      string
+}
+
+// AuthGrpcMiddleware builds the gRPC server options that enforce bearer-token
+// authentication, configured through peerdbenv.GetPeerDBOAuthConfig.
+//
+// Methods whose full name appears in unauthenticatedMethods skip the check. When
+// no issuer URL is configured, authentication is disabled and (nil, nil) is
+// returned. When it is configured but no identity provider can be loaded, the
+// error is returned. Unary and streaming RPCs are both covered.
+func AuthGrpcMiddleware(unauthenticatedMethods []string) ([]grpc.ServerOption, error) {
+	oauthConfig := peerdbenv.GetPeerDBOAuthConfig()
+	oauthJwtClaims := map[string]string{}
+	if oauthConfig.OAuthJwtClaimKey != "" {
+		oauthJwtClaims[oauthConfig.OAuthJwtClaimKey] = oauthConfig.OAuthClaimValue
+	}
+	cfg := AuthenticationConfig{
+		Enabled:               oauthConfig.OAuthIssuerUrl != "",
+		KeySetJSON:            oauthConfig.KeySetJson,
+		OAuthDiscoveryEnabled: oauthConfig.OAuthDiscoveryEnabled,
+		OAuthIssuerUrl:        oauthConfig.OAuthIssuerUrl,
+		OauthJwtCustomClaims:  oauthJwtClaims,
+	}
+	// load identity providers before checking if authentication is enabled so configuration can be validated
+	ip, err := identityProvidersFromConfig(cfg)
+
+	if !cfg.Enabled {
+		if err != nil { // if there was an error loading identity providers, warn only if authentication is disabled
+			slog.Warn("OAuth is disabled", slog.Any("error", err))
+		}
+
+		slog.Warn("authentication is disabled")
+
+		return nil, nil
+	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	unauthenticatedMethodsMap := make(map[string]struct{}, len(unauthenticatedMethods))
+	for _, method := range unauthenticatedMethods {
+		unauthenticatedMethodsMap[method] = struct{}{}
+	}
+
+	// authorize is shared by the unary and stream interceptors so that both kinds
+	// of RPC go through exactly the same check.
+	authorize := func(ctx context.Context, fullMethod string) error {
+		if _, exempt := unauthenticatedMethodsMap[fullMethod]; exempt {
+			return nil
+		}
+
+		authHeaders := metadata.ValueFromIncomingContext(ctx, "Authorization")
+		if len(authHeaders) > 1 {
+			return status.Error(codes.Unauthenticated, "multiple Authorization headers supplied, request rejected")
+		}
+		var authHeader string
+		if len(authHeaders) == 1 {
+			authHeader = authHeaders[0]
+		}
+
+		if _, err := validateRequestToken(authHeader, cfg.OauthJwtCustomClaims, ip...); err != nil {
+			slog.Debug("failed to validate request token", slog.Any("error", err))
+			return status.Error(codes.Unauthenticated, err.Error())
+		}
+		return nil
+	}
+
+	return []grpc.ServerOption{
+		grpc.ChainUnaryInterceptor(func(
+			ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
+		) (any, error) {
+			if err := authorize(ctx, info.FullMethod); err != nil {
+				return nil, err
+			}
+			return handler(ctx, req)
+		}),
+		// Without a stream interceptor, streaming RPCs would bypass authentication entirely.
+		grpc.ChainStreamInterceptor(func(
+			srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler,
+		) error {
+			if err := authorize(ss.Context(), info.FullMethod); err != nil {
+				return err
+			}
+			return handler(srv, ss)
+		}),
+	}, nil
+}
+
+// validateRequestToken checks the token carried in authHeader and returns the raw
+// token bytes on success.
+//
+// The token's issuer must match one of the providers in ip, the token must pass
+// jwt.Validate with that provider's options, its signature must verify against
+// the provider's key set, and every entry of claims must be present among the
+// token's private claims with an equal value.
+func validateRequestToken(authHeader string, claims map[string]string, ip ...identityProvider) ([]byte, error) {
+	payload, err := jwtFromRequest(authHeader)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse authorization header: %w", err)
+	}
+
+	// The token is parsed without verification only to read its issuer and select
+	// the key set; the signature is checked by jws.Verify below before returning.
+	token, err := jwt.ParseInsecure(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse token: %w", err)
+	}
+
+	provider, err := identityProviderByToken(ip, token)
+	if err != nil {
+		return nil, err
+	}
+
+	validateOpts := identityProviderValidateOpts(provider)
+	if err := jwt.Validate(token, validateOpts...); err != nil {
+		return nil, fmt.Errorf("failed to validate token: %w", err)
+	}
+
+	if _, err := jws.Verify(payload, jws.WithKeySet(provider.keySet)); err != nil {
+		return nil, fmt.Errorf("failed to verify token: %w", err)
+	}
+
+	privateClaims := token.PrivateClaims()
+	for key, value := range claims {
+		if privateClaims[key] != value {
+			return nil, fmt.Errorf("token claim %s mismatch", key)
+		}
+	}
+
+	return payload, nil
+}
+
+// jwtFromRequest extracts the JWT from the Authorization header value.
+//
+// Surrounding whitespace is trimmed and a leading "Bearer " scheme is removed if
+// present. The scheme is matched case-insensitively, as HTTP authentication
+// scheme names are case-insensitive. A value without the scheme is returned as is.
+func jwtFromRequest(authHeader string) ([]byte, error) {
+	const scheme = "bearer "
+
+	token := strings.TrimSpace(authHeader)
+	if token == "" {
+		return nil, errors.New("missing Authorization header")
+	}
+
+	if len(token) > len(scheme) && strings.EqualFold(token[:len(scheme)], scheme) {
+		token = strings.TrimSpace(token[len(scheme):])
+	}
+	return []byte(token), nil
+}
+
+// identityProviderValidateOpts returns the options passed to jwt.Validate for
+// provider: an issuer match, an expiration check and, if set, the provider's own option.
+func identityProviderValidateOpts(provider identityProvider) []jwt.ValidateOption {
+	validateOpts := []jwt.ValidateOption{
+		jwt.WithIssuer(provider.issuer),
+		jwt.WithValidator(jwt.IsExpirationValid()),
+	}
+
+	if provider.validateOpt != nil {
+		validateOpts = append(validateOpts, provider.validateOpt)
+	}
+	return validateOpts
+}
+
+// identityProviderByToken returns the first provider in ip whose issuer equals the
+// token's issuer. A token without an issuer never matches.
+func identityProviderByToken(ip []identityProvider, token jwt.Token) (identityProvider, error) {
+	issuer := token.Issuer()
+	if issuer != "" {
+		for _, p := range ip {
+			if p.issuer == issuer {
+				return p, nil
+			}
+		}
+	}
+	return identityProvider{}, fmt.Errorf("identity provider for issuer %s not found", issuer)
+}
+
+// identityProviderResolver builds an identity provider from cfg, or returns
+// (nil, nil) when cfg does not configure that kind of provider.
+type identityProviderResolver func(cfg AuthenticationConfig) (*identityProvider, error)
+
+// identityProvidersFromConfig runs every resolver against cfg and collects the
+// providers they produce. It fails if a resolver fails or none yields a provider.
+func identityProvidersFromConfig(cfg AuthenticationConfig) ([]identityProvider, error) {
+	resolvers := []identityProviderResolver{
+		keysetIdentityProvider,
+		openIdIdentityProvider,
+	}
+
+	ip := make([]identityProvider, 0, len(resolvers))
+	for _, resolver := range resolvers {
+		provider, err := resolver(cfg)
+		if err != nil {
+			return nil, err
+		}
+
+		if provider == nil {
+			continue
+		}
+
+		ip = append(ip, *provider)
+	}
+
+	if len(ip) == 0 {
+		return nil, errors.New("no identity providers configured")
+	}
+
+	return ip, nil
+}
+
+// openIdIdentityProvider returns a provider whose key set is served from the
+// issuer's /.well-known/jwks.json endpoint through a jwk cache. It returns
+// (nil, nil) when no issuer URL is configured or discovery is not enabled.
+func openIdIdentityProvider(cfg AuthenticationConfig) (*identityProvider, error) {
+	if cfg.OAuthIssuerUrl == "" {
+		slog.Debug("OAuth Issuer Url not configured for identity provider")
+		return nil, nil
+	}
+	if !cfg.OAuthDiscoveryEnabled {
+		slog.Debug("OAuth discovery not enabled for identity provider")
+		return nil, nil
+	}
+	issuer := cfg.OAuthIssuerUrl
+	// This is a well known URL for jwks defined in OpenID discovery spec
+	jwksDiscoveryUrl, err := url.JoinPath(cfg.OAuthIssuerUrl, "/.well-known/jwks.json")
+	if err != nil {
+		return nil, err
+	}
+
+	cache := jwk.NewCache(context.Background())
+	if err := cache.Register(jwksDiscoveryUrl); err != nil {
+		return nil, fmt.Errorf("failed to register JWK key set from Discovery URL %s: %w", jwksDiscoveryUrl, err)
+	}
+	set := jwk.NewCachedSet(cache, jwksDiscoveryUrl)
+
+	slog.Info("JWK key set from Discovery Endpoint loaded", slog.String("jwks", jwksDiscoveryUrl), slog.Int("size", set.Len()))
+
+	return &identityProvider{
+		issuer:      issuer,
+		keySet:      set,
+		validateOpt: jwt.WithIssuer(issuer),
+	}, nil
+}
+
+// keysetIdentityProvider returns a provider backed by the JWK set parsed from
+// cfg.KeySetJSON, or (nil, nil) when no key set JSON is configured.
+func keysetIdentityProvider(cfg AuthenticationConfig) (*identityProvider, error) {
+	if cfg.KeySetJSON == "" {
+		slog.Debug("JWK key set JSON not configured for identity provider")
+		return nil, nil
+	}
+
+	set, err := jwk.ParseString(cfg.KeySetJSON)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse JWK key set from JSON: %w", err)
+	}
+
+	slog.Info("JWK key set from JSON loaded", slog.Int("size", set.Len()))
+
+	return &identityProvider{
+		issuer: cfg.OAuthIssuerUrl,
+		keySet: set,
+	}, nil
+}
diff --git a/flow/cmd/api.go b/flow/cmd/api.go
index 8936b75a01..a0530d9db8 100644
--- a/flow/cmd/api.go
+++ b/flow/cmd/api.go
@@ -23,6 +23,7 @@ import (
 	"google.golang.org/grpc/health/grpc_health_v1"
 	"google.golang.org/grpc/reflection"
 
+	"github.com/PeerDB-io/peer-flow/auth"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/logger"
 	"github.com/PeerDB-io/peer-flow/peerdbenv"
@@ -213,7 +214,14 @@ func APIMain(ctx context.Context, args *APIServerParams) error {
 		return fmt.Errorf("unable to create Temporal client: %w", err)
 	}
 
-	grpcServer := grpc.NewServer()
+	options, err := auth.AuthGrpcMiddleware([]string{
+		grpc_health_v1.Health_Check_FullMethodName,
+		grpc_health_v1.Health_Watch_FullMethodName,
+	})
+	if err != nil {
+		return err
+	}
+	grpcServer := grpc.NewServer(options...)
catalogPool, err := peerdbenv.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { diff --git a/flow/cmd/handler.go b/flow/cmd/handler.go index 5afad52b2d..9458fa4488 100644 --- a/flow/cmd/handler.go +++ b/flow/cmd/handler.go @@ -247,10 +247,7 @@ func (h *FlowRequestHandler) updateQRepConfigInCatalog( ctx context.Context, cfg *protos.QRepConfig, ) error { - var cfgBytes []byte - var err error - - cfgBytes, err = proto.Marshal(cfg) + cfgBytes, err := proto.Marshal(cfg) if err != nil { return fmt.Errorf("unable to marshal qrep config: %w", err) } diff --git a/flow/go.mod b/flow/go.mod index 9050a2ba71..54177874de 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -37,6 +37,7 @@ require ( github.com/jmoiron/sqlx v1.4.0 github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.9 + github.com/lestrrat-go/jwx/v2 v2.1.1 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 github.com/microsoft/go-mssqldb v1.7.2 @@ -93,6 +94,7 @@ require ( github.com/cockroachdb/redact v1.1.5 // indirect github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect github.com/danieljoos/wincred v1.2.2 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/dvsekhvalnov/jose2go v1.7.0 // indirect github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -107,6 +109,11 @@ require ( github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.6 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect github.com/mtibben/percent v0.2.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nexus-rpc/sdk-go v0.0.10 // indirect diff --git 
a/flow/go.sum b/flow/go.sum index 44f0e26a34..ea536002c1 100644 --- a/flow/go.sum +++ b/flow/go.sum @@ -156,6 +156,8 @@ github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/djherbis/buffer v1.1.0/go.mod h1:VwN8VdFkMY0DCALdY8o00d3IZ6Amz/UNVMWcSaJT44o= github.com/djherbis/buffer v1.2.0 h1:PH5Dd2ss0C7CRRhQCZ2u7MssF+No9ide8Ye71nPHcrQ= github.com/djherbis/buffer v1.2.0/go.mod h1:fjnebbZjCUpPinBRD+TDwXSOeNQ7fPQWLfGQqiAiUyE= @@ -313,6 +315,18 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k= +github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= 
+github.com/lestrrat-go/jwx/v2 v2.1.1 h1:Y2ltVl8J6izLYFs54BVcpXLv5msSW4o8eXwnzZLI32E=
+github.com/lestrrat-go/jwx/v2 v2.1.1/go.mod h1:4LvZg7oxu6Q5VJwn7Mk/UwooNRnTHUpXBj2C4j3HNx0=
+github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
+github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
 github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
 github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/linkedin/goavro/v2 v2.13.0 h1:L8eI8GcuciwUkt41Ej62joSZS4kKaYIUdze+6for9NU=
diff --git a/flow/peerdbenv/oauth.go b/flow/peerdbenv/oauth.go
new file mode 100644
index 0000000000..cd76b30193
--- /dev/null
+++ b/flow/peerdbenv/oauth.go
@@ -0,0 +1,35 @@
+package peerdbenv
+
+import "strconv"
+
+// PeerDBOAuthConfig carries the OAuth settings read from the PEERDB_OAUTH_* environment variables.
+type PeerDBOAuthConfig struct {
+	// Issuer URL of accepted tokens. Domain and issuer are assumed to coincide;
+	// cases where they differ are not handled yet.
+	OAuthIssuerUrl string `json:"oauth_issuer_url"`
+	// JSON-encoded key set, for when keys are supplied directly.
+	KeySetJson string `json:"key_set_json"`
+	// Optional custom claim (key and expected value) that may additionally be validated.
+	OAuthJwtClaimKey string `json:"oauth_jwt_claim_key"`
+	OAuthClaimValue  string `json:"oauth_jwt_claim_value"`
+	// When enabled, the /.well-known/ OpenID discovery endpoints are used, so the
+	// key set etc. need not be specified.
+	OAuthDiscoveryEnabled bool `json:"oauth_discovery_enabled"`
+}
+
+// GetPeerDBOAuthConfig assembles the OAuth configuration from the environment.
+// String settings use "" as their GetEnvString default; PEERDB_OAUTH_DISCOVERY_ENABLED
+// defaults to "false", and a value strconv.ParseBool rejects is treated as false too.
+func GetPeerDBOAuthConfig() PeerDBOAuthConfig {
+	issuerURL := GetEnvString("PEERDB_OAUTH_ISSUER_URL", "")
+	// ParseBool returns false together with any error, so the error can be dropped.
+	discoveryEnabled, _ := strconv.ParseBool(GetEnvString("PEERDB_OAUTH_DISCOVERY_ENABLED", "false"))
+
+	return PeerDBOAuthConfig{
+		OAuthIssuerUrl:        issuerURL,
+		OAuthDiscoveryEnabled: discoveryEnabled,
+		KeySetJson:            GetEnvString("PEERDB_OAUTH_KEYSET_JSON", ""),
+		OAuthJwtClaimKey:      GetEnvString("PEERDB_OAUTH_JWT_CLAIM_KEY", ""),
+		OAuthClaimValue:       GetEnvString("PEERDB_OAUTH_JWT_CLAIM_VALUE", ""),
+	}
+}