From 7048a648ba6a55f0b62c25c905e66e8ca68f5079 Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 22 Jan 2025 19:05:22 -0800 Subject: [PATCH 1/2] Check whether mfa is required for an app session through the correct cluster. --- lib/web/apiserver.go | 18 +++++++++--------- lib/web/mfa.go | 43 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index d541467bf0582..a062381b642e2 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -4562,15 +4562,6 @@ func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http. // remote trusted cluster) as specified by the ":site" url parameter. func (h *Handler) getSiteByParams(ctx context.Context, sctx *SessionContext, p httprouter.Params) (reversetunnelclient.RemoteSite, error) { clusterName := p.ByName("site") - if clusterName == currentSiteShortcut { - res, err := h.cfg.ProxyClient.GetClusterName() - if err != nil { - h.logger.WarnContext(ctx, "Failed to query cluster name", "error", err) - return nil, trace.Wrap(err) - } - clusterName = res.GetClusterName() - } - site, err := h.getSiteByClusterName(ctx, sctx, clusterName) if err != nil { return nil, trace.Wrap(err) @@ -4580,6 +4571,15 @@ func (h *Handler) getSiteByParams(ctx context.Context, sctx *SessionContext, p h } func (h *Handler) getSiteByClusterName(ctx context.Context, sctx *SessionContext, clusterName string) (reversetunnelclient.RemoteSite, error) { + if clusterName == currentSiteShortcut { + res, err := h.cfg.ProxyClient.GetClusterName() + if err != nil { + h.logger.WarnContext(ctx, "Failed to query cluster name", "error", err) + return nil, trace.Wrap(err) + } + clusterName = res.GetClusterName() + } + proxy, err := h.ProxyWithRoles(ctx, sctx) if err != nil { h.logger.WarnContext(ctx, "Failed to get proxy with roles", "error", err) diff --git a/lib/web/mfa.go b/lib/web/mfa.go index c59b0ae10cbd7..61b181307d4b4 100644 --- a/lib/web/mfa.go +++ b/lib/web/mfa.go @@ -20,6 +20,7 @@ package web import ( "context" + "fmt" "net/http" "net/url" "strings" @@ -180,6 +181,8 @@ type createAuthenticateChallengeRequest struct { // createAuthenticateChallengeHandle creates and returns MFA authentication challenges for the user in context (logged in user). // Used when users need to re-authenticate their second factors. func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *http.Request, p httprouter.Params, c *SessionContext) (interface{}, error) { + ctx := r.Context() + var req createAuthenticateChallengeRequest if err := httplib.ReadResourceJSON(r, &req); err != nil { return nil, trace.Wrap(err) @@ -190,12 +193,40 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht return nil, trace.Wrap(err) } - var mfaRequiredCheckProto *proto.IsMFARequiredRequest + var mfaRequiredCheck *proto.IsMFARequiredRequest if req.IsMFARequiredRequest != nil { - mfaRequiredCheckProto, err = h.checkAndGetProtoRequest(r.Context(), c, req.IsMFARequiredRequest) + mfaRequiredCheckProto, err := h.checkAndGetProtoRequest(ctx, c, req.IsMFARequiredRequest) if err != nil { return nil, trace.Wrap(err) } + + // If this is an app session request, check if mfa is required through the app's parent cluster. + // Otherwise, check within the challenge request. + if appParams := mfaRequiredCheckProto.GetApp(); appParams != nil { + if appParams.ClusterName != c.cfg.RootClusterName { + fmt.Printf("clusterName: %v\n", appParams.ClusterName) + site, err := h.getSiteByClusterName(ctx, c, appParams.ClusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + clusterClient, err := c.GetUserClient(ctx, site) + if err != nil { + return false, trace.Wrap(err) + } + + res, err := clusterClient.IsMFARequired(ctx, mfaRequiredCheckProto) + if err != nil { + return false, trace.Wrap(err) + } + + if !res.Required { + return &client.MFAAuthenticateChallenge{}, nil + } + } + } else { + mfaRequiredCheck = mfaRequiredCheckProto + } } allowReuse := mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_NO @@ -219,11 +250,11 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht query.Set("channel_id", channelID) ssoClientRedirectURL.RawQuery = query.Encode() - chal, err := clt.CreateAuthenticateChallenge(r.Context(), &proto.CreateAuthenticateChallengeRequest{ + chal, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ ContextUser: &proto.ContextUser{}, }, - MFARequiredCheck: mfaRequiredCheckProto, + MFARequiredCheck: mfaRequiredCheck, ChallengeExtensions: &mfav1.ChallengeExtensions{ Scope: mfav1.ChallengeScope(req.ChallengeScope), AllowReuse: allowReuse, @@ -533,7 +564,9 @@ func (h *Handler) checkAndGetProtoRequest(ctx context.Context, scx *SessionConte protoReq = &proto.IsMFARequiredRequest{ Target: &proto.IsMFARequiredRequest_App{ App: &proto.RouteToApp{ - Name: resolvedApp.App.GetName(), + Name: resolvedApp.App.GetName(), + PublicAddr: resolvedApp.App.GetPublicAddr(), + ClusterName: resolvedApp.ClusterName, }, }, } From e2d0ae8761106948f670fb2ce665b402cfaa183b Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 23 Jan 2025 11:42:33 -0800 Subject: [PATCH 2/2] Add clusterId param for isMfaRequired check in the web api. --- lib/web/mfa.go | 58 +++++++++---------- .../teleport/src/AppLauncher/AppLauncher.tsx | 1 + web/packages/teleport/src/config.ts | 6 +- .../teleport/src/services/auth/auth.ts | 11 +++- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/lib/web/mfa.go b/lib/web/mfa.go index 61b181307d4b4..6326978bb7dd3 100644 --- a/lib/web/mfa.go +++ b/lib/web/mfa.go @@ -20,7 +20,6 @@ package web import ( "context" - "fmt" "net/http" "net/url" "strings" @@ -193,39 +192,37 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht return nil, trace.Wrap(err) } - var mfaRequiredCheck *proto.IsMFARequiredRequest + var mfaRequiredCheckProto *proto.IsMFARequiredRequest if req.IsMFARequiredRequest != nil { - mfaRequiredCheckProto, err := h.checkAndGetProtoRequest(ctx, c, req.IsMFARequiredRequest) + mfaRequiredCheckProto, err = h.checkAndGetProtoRequest(ctx, c, req.IsMFARequiredRequest) if err != nil { return nil, trace.Wrap(err) } - // If this is an app session request, check if mfa is required through the app's parent cluster. - // Otherwise, check within the challenge request. - if appParams := mfaRequiredCheckProto.GetApp(); appParams != nil { - if appParams.ClusterName != c.cfg.RootClusterName { - fmt.Printf("clusterName: %v\n", appParams.ClusterName) - site, err := h.getSiteByClusterName(ctx, c, appParams.ClusterName) - if err != nil { - return nil, trace.Wrap(err) - } - - clusterClient, err := c.GetUserClient(ctx, site) - if err != nil { - return false, trace.Wrap(err) - } - - res, err := clusterClient.IsMFARequired(ctx, mfaRequiredCheckProto) - if err != nil { - return false, trace.Wrap(err) - } - - if !res.Required { - return &client.MFAAuthenticateChallenge{}, nil - } + // If the MFA requirement check is being performed for a leaf host, we must check directly + // with the leaf cluster before the authentication challenge request through root. + if req.IsMFARequiredRequest.ClusterID != "" && req.IsMFARequiredRequest.ClusterID != c.cfg.RootClusterName { + site, err := h.getSiteByClusterName(ctx, c, req.IsMFARequiredRequest.ClusterID) + if err != nil { + return nil, trace.Wrap(err) } - } else { - mfaRequiredCheck = mfaRequiredCheckProto + + clusterClient, err := c.GetUserClient(ctx, site) + if err != nil { + return false, trace.Wrap(err) + } + + res, err := clusterClient.IsMFARequired(ctx, mfaRequiredCheckProto) + if err != nil { + return false, trace.Wrap(err) + } + + if !res.Required { + return &client.MFAAuthenticateChallenge{}, nil + } + + // We don't want to check again through the root cluster below. + mfaRequiredCheckProto = nil } } @@ -254,7 +251,7 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ ContextUser: &proto.ContextUser{}, }, - MFARequiredCheck: mfaRequiredCheck, + MFARequiredCheck: mfaRequiredCheckProto, ChallengeExtensions: &mfav1.ChallengeExtensions{ Scope: mfav1.ChallengeScope(req.ChallengeScope), AllowReuse: allowReuse, @@ -459,6 +456,9 @@ type isMFARequiredApp struct { type isMFARequiredAdminAction struct{} type isMFARequiredRequest struct { + // ClusterID is the ID of the cluster to check against for MFA requirements. + // If not set, MFA requirements will be checked against the root cluster. + ClusterID string `json:"clusterId,omitempty"` // Database contains fields required to check if target database // requires MFA check. Database *isMFARequiredDatabase `json:"database,omitempty"` diff --git a/web/packages/teleport/src/AppLauncher/AppLauncher.tsx b/web/packages/teleport/src/AppLauncher/AppLauncher.tsx index 8288b1d0cdc29..f84ad39547773 100644 --- a/web/packages/teleport/src/AppLauncher/AppLauncher.tsx +++ b/web/packages/teleport/src/AppLauncher/AppLauncher.tsx @@ -41,6 +41,7 @@ export function AppLauncher() { req: { scope: MfaChallengeScope.USER_SESSION, isMfaRequiredRequest: { + clusterId: pathParams.clusterId, app: { fqdn: pathParams.fqdn, cluster_name: pathParams.clusterId, diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index a5d3661efee56..5580f7d190dfe 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -749,8 +749,10 @@ const cfg = { return generatePath(cfg.api.connectionDiagnostic, { clusterId }); }, - getMfaRequiredUrl() { - const clusterId = cfg.proxyCluster; + getMfaRequiredUrl(clusterId?: string) { + if (!clusterId) { + clusterId = cfg.proxyCluster; + } return generatePath(cfg.api.mfaRequired, { clusterId }); }, diff --git a/web/packages/teleport/src/services/auth/auth.ts b/web/packages/teleport/src/services/auth/auth.ts index 100259d6dfc20..7d20ef5e00f3c 100644 --- a/web/packages/teleport/src/services/auth/auth.ts +++ b/web/packages/teleport/src/services/auth/auth.ts @@ -470,7 +470,7 @@ function checkMfaRequired( params: IsMfaRequiredRequest, abortSignal? ): Promise { - return api.post(cfg.getMfaRequiredUrl(), params, abortSignal); + return api.post(cfg.getMfaRequiredUrl(params.clusterId), params, abortSignal); } function base64EncodeUnicode(str: string) { @@ -508,13 +508,18 @@ function waitForMessage( export default auth; -export type IsMfaRequiredRequest = +export type IsMfaRequiredRequest = { + // clusterId is the cluster to check mfa requirement against. When connecting to + // leaf hosts, this should be set to the leaf clusterId. + clusterId?: string; +} & ( | IsMfaRequiredDatabase | IsMfaRequiredNode | IsMfaRequiredKube | IsMfaRequiredWindowsDesktop | IsMfaRequiredApp - | IsMfaRequiredAdminAction; + | IsMfaRequiredAdminAction +); export type IsMfaRequiredResponse = { required: boolean;