diff --git a/cmd/entrypoints/serve.go b/cmd/entrypoints/serve.go index 1325e95c3..905ca6417 100644 --- a/cmd/entrypoints/serve.go +++ b/cmd/entrypoints/serve.go @@ -41,6 +41,9 @@ import ( "google.golang.org/grpc/reflection" ) +const SixteenMegabytes int = 16777216 // 16 * 2^20 +const SixteenKibibytes uint32 = 16384 // 16 * 2^10 + var defaultCorsHeaders = []string{"Content-Type"} // serveCmd represents the serve command @@ -217,6 +220,7 @@ func serveGatewayInsecure(ctx context.Context, cfg *config.ServerConfig) error { logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) httpServer, err := newHTTPServer(ctx, cfg, authContext, cfg.GetGrpcHostAddress(), grpc.WithInsecure(), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(SixteenMegabytes)), grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes)) if err != nil { return err diff --git a/pkg/auth/auth_context.go b/pkg/auth/auth_context.go index 95bc2c307..bed7ec070 100644 --- a/pkg/auth/auth_context.go +++ b/pkg/auth/auth_context.go @@ -190,3 +190,16 @@ func GetOauth2Config(options config.OAuthOptions) (oauth2.Config, error) { }, }, nil } + +func GetL5Oauth2Config(mainConfig *oauth2.Config) oauth2.Config { + return oauth2.Config{ + RedirectURL: "https://flyte-rs.av.lyft.net/callback", + ClientID: mainConfig.ClientID, + ClientSecret: mainConfig.ClientSecret, + Scopes: mainConfig.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: mainConfig.Endpoint.AuthURL, + TokenURL: mainConfig.Endpoint.TokenURL, + }, + } +} diff --git a/pkg/auth/handlers.go b/pkg/auth/handlers.go index c9a83386d..286bb4787 100644 --- a/pkg/auth/handlers.go +++ b/pkg/auth/handlers.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "time" "github.com/lyft/flyteadmin/pkg/audit" @@ -17,6 +18,7 @@ import ( "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/errors" "github.com/lyft/flytestdlib/logger" + "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -65,6 +67,7 @@ func RefreshTokensIfExists(ctx context.Context, authContext interfaces.Authentic } func GetLoginHandler(ctx context.Context, authContext interfaces.AuthenticationContext) http.HandlerFunc { + l5OauthConfig := GetL5Oauth2Config(authContext.OAuth2Config()) return func(writer http.ResponseWriter, request *http.Request) { csrfCookie := NewCsrfCookie() csrfToken := csrfCookie.Value @@ -74,6 +77,13 @@ func GetLoginHandler(ctx context.Context, authContext interfaces.AuthenticationC logger.Debugf(ctx, "Setting CSRF state cookie to %s and state to %s\n", csrfToken, state) url := authContext.OAuth2Config().AuthCodeURL(state) queryParams := request.URL.Query() + + // Special hack for L5 to last til the end of Q1 + if strings.Contains(request.Host, "flyte-rs.av.lyft.net") { + logger.Debugf(ctx, "Changing the callback in the /authorize call to point to L5") + url = l5OauthConfig.AuthCodeURL(state) + } + if flowEndRedirectURL := queryParams.Get(RedirectURLParameter); flowEndRedirectURL != "" { redirectCookie := NewRedirectCookie(ctx, flowEndRedirectURL) if redirectCookie != nil { @@ -87,6 +97,7 @@ func GetLoginHandler(ctx context.Context, authContext interfaces.AuthenticationC } func GetCallbackHandler(ctx context.Context, authContext interfaces.AuthenticationContext) http.HandlerFunc { + l5OauthConfig := GetL5Oauth2Config(authContext.OAuth2Config()) return func(writer http.ResponseWriter, request *http.Request) { logger.Debugf(ctx, "Running callback handler...") authorizationCode := request.FormValue(AuthorizationResponseCodeType) @@ -98,11 +109,22 @@ func GetCallbackHandler(ctx context.Context, authContext interfaces.Authenticati return } - token, err := authContext.OAuth2Config().Exchange(ctx, authorizationCode) - if err != nil { - logger.Errorf(ctx, "Error when exchanging code %s", err) - writer.WriteHeader(http.StatusForbidden) - return + var token *oauth2.Token + // Additional hacks for L5 + if strings.Contains(request.Host, "flyte-rs.av.lyft.net") { + token, err = l5OauthConfig.Exchange(ctx, authorizationCode) + if err != nil { + logger.Errorf(ctx, "Error when exchanging code %s", err) + writer.WriteHeader(http.StatusForbidden) + return + } + } else { + token, err = authContext.OAuth2Config().Exchange(ctx, authorizationCode) + if err != nil { + logger.Errorf(ctx, "Error when exchanging code %s", err) + writer.WriteHeader(http.StatusForbidden) + return + } } err = authContext.CookieManager().SetTokenCookies(ctx, writer, token)