Skip to content

Commit

Permalink
support grpc oauth (#1980)
Browse files Browse the repository at this point in the history
Enables using OAuth2 for calls to the Flow API

![Screenshot 2024-09-09 at 17 36
13](https://github.com/user-attachments/assets/26f3d8ac-7314-421e-892a-1c5b0d255b9a)

Requires the following env vars:
- `PEERDB_OAUTH_ISSUER_URL` - This is the OAuth Issuer URL of the JWT
(Like `https://{AUTH0_DOMAIN}/`)
- `PEERDB_OAUTH_DISCOVERY_ENABLED` - Set this to `true` to enable
discovery via `/.well-known/jwks.json` endpoint defined in openID spec
- `PEERDB_OAUTH_KEYSET_JSON` - If custom json keyset is to be provided.
- `PEERDB_OAUTH_JWT_CLAIM_KEY`, `PEERDB_OAUTH_JWT_CLAIM_VALUE` - any
custom key-value to be additionally checked while validating the
incoming jwt

Health Endpoints are explicitly excluded from auth.
  • Loading branch information
serprex authored Sep 10, 2024
1 parent a9bb42b commit 97c66a2
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 5 deletions.
242 changes: 242 additions & 0 deletions flow/auth/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package auth

import (
"context"
"errors"
"fmt"
"log/slog"
"net/url"
"strings"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/PeerDB-io/peer-flow/peerdbenv"
)

//nolint:lll
type AuthenticationConfig struct {
OauthJwtCustomClaims map[string]string `json:"oauth_custom_claims" yaml:"oauth_custom_claims" mapstructure:"oauth_custom_claims"`
KeySetJSON string `json:"key_set_json" yaml:"key_set_json" mapstructure:"key_set_json"`
OAuthIssuerUrl string `json:"oauth_domain" yaml:"oauth_domain" mapstructure:"oauth_domain"`
Enabled bool `json:"enabled" yaml:"enabled" mapstructure:"enabled"`
OAuthDiscoveryEnabled bool `json:"oauth_discovery_enabled" yaml:"oauth_discovery_enabled" mapstructure:"oauth_discovery_enabled"`
}

type identityProvider struct {
keySet jwk.Set
validateOpt jwt.ValidateOption
issuer string
}

func AuthGrpcMiddleware(unauthenticatedMethods []string) ([]grpc.ServerOption, error) {
oauthConfig := peerdbenv.GetPeerDBOAuthConfig()
oauthJwtClaims := map[string]string{}
if oauthConfig.OAuthJwtClaimKey != "" {
oauthJwtClaims[oauthConfig.OAuthJwtClaimKey] = oauthConfig.OAuthClaimValue
}
cfg := AuthenticationConfig{
Enabled: oauthConfig.OAuthIssuerUrl != "",
KeySetJSON: oauthConfig.KeySetJson,
OAuthDiscoveryEnabled: oauthConfig.OAuthDiscoveryEnabled,
OAuthIssuerUrl: oauthConfig.OAuthIssuerUrl,
OauthJwtCustomClaims: oauthJwtClaims,
}
// load identity providers before checking if authentication is enabled so configuration can be validated
ip, err := identityProvidersFromConfig(cfg)

if !cfg.Enabled {
if err != nil { // if there was an error loading identity providers, warn only if authentication is disabled
slog.Warn("OAuth is disabled", slog.Any("error", err))
}

slog.Warn("authentication is disabled")

return nil, nil
}

if err != nil {
return nil, err
}

unauthenticatedMethodsMap := make(map[string]struct{}, len(unauthenticatedMethods))
for _, method := range unauthenticatedMethods {
unauthenticatedMethodsMap[method] = struct{}{}
}

return []grpc.ServerOption{
grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if _, unauthorized := unauthenticatedMethodsMap[info.FullMethod]; !unauthorized {
var authHeader string
authHeaders := metadata.ValueFromIncomingContext(ctx, "Authorization")
if len(authHeaders) == 1 {
authHeader = authHeaders[0]
} else if len(authHeaders) > 1 {
return nil, status.Errorf(codes.Unauthenticated, "multiple Authorization headers supplied, request rejected")
}
_, err := validateRequestToken(authHeader, cfg.OauthJwtCustomClaims, ip...)
if err != nil {
slog.Debug("failed to validate request token", slog.Any("error", err))
return nil, status.Errorf(codes.Unauthenticated, "%s", err.Error())
}
}
return handler(ctx, req)
}),
}, nil
}

func validateRequestToken(authHeader string, claims map[string]string, ip ...identityProvider) ([]byte, error) {
payload, err := jwtFromRequest(authHeader)
if err != nil {
return nil, fmt.Errorf("failed to parse authorization header: %w", err)
}

// We could simplify to jwt.Parse(payload, opts...), but it is ok for now
token, err := jwt.ParseInsecure(payload)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}

provider, err := identityProviderByToken(ip, token)
if err != nil {
return nil, err
}

validateOpts := identityProviderValidateOpts(provider)
if err := jwt.Validate(token, validateOpts...); err != nil {
return nil, fmt.Errorf("failed to validate token: %w", err)
}

if _, err := jws.Verify(payload, jws.WithKeySet(provider.keySet)); err != nil {
return nil, fmt.Errorf("failed to verify token: %w", err)
}

for key, value := range claims {
if token.PrivateClaims()[key] != value {
return nil, fmt.Errorf("token claim %s mismatch", key)
}
}

return payload, nil
}

// jwtFromRequest extracts the JWT token from the Authorization header.
// it truncates the "Bearer" prefix from the header value if exists.
func jwtFromRequest(authHeader string) ([]byte, error) {
if authHeader == "" {
return nil, errors.New("missing Authorization header")
}

return []byte(strings.TrimPrefix(authHeader, "Bearer ")), nil
}

func identityProviderValidateOpts(provider identityProvider) []jwt.ValidateOption {
validateOpts := []jwt.ValidateOption{
jwt.WithIssuer(provider.issuer),
jwt.WithValidator(jwt.IsExpirationValid()),
}

if provider.validateOpt != nil {
validateOpts = append(validateOpts, provider.validateOpt)
}
return validateOpts
}

func identityProviderByToken(ip []identityProvider, token jwt.Token) (identityProvider, error) {
var provider identityProvider
for _, p := range ip {
if p.issuer == token.Issuer() {
provider = p
break
}
}

if provider.issuer == "" {
return identityProvider{}, fmt.Errorf("identity provider for issuer %s not found", token.Issuer())
}
return provider, nil
}

type identityProviderResolver func(cfg AuthenticationConfig) (*identityProvider, error)

func identityProvidersFromConfig(cfg AuthenticationConfig) ([]identityProvider, error) {
resolvers := []identityProviderResolver{
keysetIdentityProvider,
openIdIdentityProvider,
}

ip := make([]identityProvider, 0, len(resolvers))
for _, resolver := range resolvers {
provider, err := resolver(cfg)
if err != nil {
return nil, err
}

if provider == nil {
continue
}

ip = append(ip, *provider)
}

if len(ip) == 0 {
return nil, errors.New("no identity providers configured")
}

return ip, nil
}

func openIdIdentityProvider(cfg AuthenticationConfig) (*identityProvider, error) {
if cfg.OAuthIssuerUrl == "" {
slog.Debug("OAuth Issuer Url not configured for identity provider")
return nil, nil
}
if !cfg.OAuthDiscoveryEnabled {
slog.Debug("OAuth discovery not enabled for identity provider")
return nil, nil
}
issuer := cfg.OAuthIssuerUrl
// This is a well known URL for jwks defined in OpenID discovery spec
jwksDiscoveryUrl, err := url.JoinPath(cfg.OAuthIssuerUrl, "/.well-known/jwks.json")
if err != nil {
return nil, err
}

cache := jwk.NewCache(context.Background())
if err := cache.Register(jwksDiscoveryUrl); err != nil {
return nil, fmt.Errorf("failed to register JWK key set from Discovery URL %s: %w", jwksDiscoveryUrl, err)
}
set := jwk.NewCachedSet(cache, jwksDiscoveryUrl)

slog.Info("JWK key set from Discovery Endpoint loaded", slog.String("jwks", jwksDiscoveryUrl), slog.Int("size", set.Len()))

return &identityProvider{
issuer: issuer,
keySet: set,
validateOpt: jwt.WithIssuer(issuer),
}, nil
}

func keysetIdentityProvider(cfg AuthenticationConfig) (*identityProvider, error) {
if cfg.KeySetJSON == "" {
slog.Debug("JWK key set JSON not configured for identity provider")
return nil, nil
}

set, err := jwk.ParseString(cfg.KeySetJSON)
if err != nil {
return nil, fmt.Errorf("failed to parse JWK key set from JSON: %w", err)
}

slog.Info("JWK key set from JSON loaded", slog.Int("size", set.Len()))

return &identityProvider{
issuer: cfg.OAuthIssuerUrl,
keySet: set,
}, nil
}
10 changes: 9 additions & 1 deletion flow/cmd/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"

"github.com/PeerDB-io/peer-flow/auth"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/logger"
"github.com/PeerDB-io/peer-flow/peerdbenv"
Expand Down Expand Up @@ -213,7 +214,14 @@ func APIMain(ctx context.Context, args *APIServerParams) error {
return fmt.Errorf("unable to create Temporal client: %w", err)
}

grpcServer := grpc.NewServer()
options, err := auth.AuthGrpcMiddleware([]string{
grpc_health_v1.Health_Check_FullMethodName,
grpc_health_v1.Health_Watch_FullMethodName,
})
if err != nil {
return err
}
grpcServer := grpc.NewServer(options...)

catalogPool, err := peerdbenv.GetCatalogConnectionPoolFromEnv(ctx)
if err != nil {
Expand Down
5 changes: 1 addition & 4 deletions flow/cmd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,7 @@ func (h *FlowRequestHandler) updateQRepConfigInCatalog(
ctx context.Context,
cfg *protos.QRepConfig,
) error {
var cfgBytes []byte
var err error

cfgBytes, err = proto.Marshal(cfg)
cfgBytes, err := proto.Marshal(cfg)
if err != nil {
return fmt.Errorf("unable to marshal qrep config: %w", err)
}
Expand Down
7 changes: 7 additions & 0 deletions flow/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ require (
github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.9
github.com/lestrrat-go/jwx/v2 v2.1.1
github.com/lib/pq v1.10.9
github.com/linkedin/goavro/v2 v2.13.0
github.com/microsoft/go-mssqldb v1.7.2
Expand Down Expand Up @@ -93,6 +94,7 @@ require (
github.com/cockroachdb/redact v1.1.5 // indirect
github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect
github.com/danieljoos/wincred v1.2.2 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
github.com/dvsekhvalnov/jose2go v1.7.0 // indirect
github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
Expand All @@ -107,6 +109,11 @@ require (
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/mtibben/percent v0.2.1 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nexus-rpc/sdk-go v0.0.10 // indirect
Expand Down
14 changes: 14 additions & 0 deletions flow/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/djherbis/buffer v1.1.0/go.mod h1:VwN8VdFkMY0DCALdY8o00d3IZ6Amz/UNVMWcSaJT44o=
github.com/djherbis/buffer v1.2.0 h1:PH5Dd2ss0C7CRRhQCZ2u7MssF+No9ide8Ye71nPHcrQ=
github.com/djherbis/buffer v1.2.0/go.mod h1:fjnebbZjCUpPinBRD+TDwXSOeNQ7fPQWLfGQqiAiUyE=
Expand Down Expand Up @@ -313,6 +315,18 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.1.1 h1:Y2ltVl8J6izLYFs54BVcpXLv5msSW4o8eXwnzZLI32E=
github.com/lestrrat-go/jwx/v2 v2.1.1/go.mod h1:4LvZg7oxu6Q5VJwn7Mk/UwooNRnTHUpXBj2C4j3HNx0=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/linkedin/goavro/v2 v2.13.0 h1:L8eI8GcuciwUkt41Ej62joSZS4kKaYIUdze+6for9NU=
Expand Down
35 changes: 35 additions & 0 deletions flow/peerdbenv/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package peerdbenv

import "strconv"

type PeerDBOAuthConfig struct {
// there can be more complex use cases where domain != issuer, but we handle them later if required
OAuthIssuerUrl string `json:"oauth_issuer_url"`
KeySetJson string `json:"key_set_json"`
// This is a custom claim we may wish to validate (if needed)
OAuthJwtClaimKey string `json:"oauth_jwt_claim_key"`
OAuthClaimValue string `json:"oauth_jwt_claim_value"`
// Enabling uses /.well-known/ OpenID discovery endpoints, thus key-set etc. don't need to be specified
OAuthDiscoveryEnabled bool `json:"oauth_discovery_enabled"`
}

func GetPeerDBOAuthConfig() PeerDBOAuthConfig {
oauthIssuerUrl := GetEnvString("PEERDB_OAUTH_ISSUER_URL", "")
oauthDiscoveryEnabledString := GetEnvString("PEERDB_OAUTH_DISCOVERY_ENABLED", "false")
oauthDiscoveryEnabled, err := strconv.ParseBool(oauthDiscoveryEnabledString)
if err != nil {
oauthDiscoveryEnabled = false
}
oauthKeysetJson := GetEnvString("PEERDB_OAUTH_KEYSET_JSON", "")

oauthJwtClaimKey := GetEnvString("PEERDB_OAUTH_JWT_CLAIM_KEY", "")
oauthJwtClaimValue := GetEnvString("PEERDB_OAUTH_JWT_CLAIM_VALUE", "")

return PeerDBOAuthConfig{
OAuthIssuerUrl: oauthIssuerUrl,
OAuthDiscoveryEnabled: oauthDiscoveryEnabled,
KeySetJson: oauthKeysetJson,
OAuthJwtClaimKey: oauthJwtClaimKey,
OAuthClaimValue: oauthJwtClaimValue,
}
}

0 comments on commit 97c66a2

Please sign in to comment.