diff --git a/pkg/grpc/auth.go b/pkg/grpc/auth.go index 0ce673e..9e378f3 100644 --- a/pkg/grpc/auth.go +++ b/pkg/grpc/auth.go @@ -4,13 +4,14 @@ import ( "context" envoyAuth "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" "github.com/gogo/googleapis/google/rpc" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "net/http" ) -func CheckGRPCAuth(ctx context.Context, authClient envoyAuth.AuthorizationClient) error { +func checkGRPCAuth(ctx context.Context, authClient envoyAuth.AuthorizationClient) error { md, ok := metadata.FromIncomingContext(ctx) if !ok { return status.Errorf(codes.Unauthenticated, "missing metadata") @@ -42,3 +43,21 @@ func CheckGRPCAuth(ctx context.Context, authClient envoyAuth.AuthorizationClient return nil } + +func CheckGRPCAuthUnaryInterceptorWrapper(authClient envoyAuth.AuthorizationClient) func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + if err := checkGRPCAuth(ctx, authClient); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +func CheckGRPCAuthStreamInterceptorWrapper(authClient envoyAuth.AuthorizationClient) func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := checkGRPCAuth(ss.Context(), authClient); err != nil { + return err + } + return handler(srv, ss) + } +}