Skip to content

Commit

Permalink
Add clusterId param for isMfaRequired check in the web api.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Jan 23, 2025
1 parent 98067d2 commit add9cc9
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
58 changes: 29 additions & 29 deletions lib/web/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package web

import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -193,39 +192,37 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht
return nil, trace.Wrap(err)
}

var mfaRequiredCheck *proto.IsMFARequiredRequest
var mfaRequiredCheckProto *proto.IsMFARequiredRequest
if req.IsMFARequiredRequest != nil {
mfaRequiredCheckProto, err := h.checkAndGetProtoRequest(ctx, c, req.IsMFARequiredRequest)
mfaRequiredCheckProto, err = h.checkAndGetProtoRequest(ctx, c, req.IsMFARequiredRequest)
if err != nil {
return nil, trace.Wrap(err)
}

// If this is an app session request, check if mfa is required through the app's parent cluster.
// Otherwise, check within the challenge request.
if appParams := mfaRequiredCheckProto.GetApp(); appParams != nil {
if appParams.ClusterName != c.cfg.RootClusterName {
fmt.Printf("clusterName: %v\n", appParams.ClusterName)
site, err := h.getSiteByClusterName(ctx, c, appParams.ClusterName)
if err != nil {
return nil, trace.Wrap(err)
}

clusterClient, err := c.GetUserClient(ctx, site)
if err != nil {
return false, trace.Wrap(err)
}

res, err := clusterClient.IsMFARequired(ctx, mfaRequiredCheckProto)
if err != nil {
return false, trace.Wrap(err)
}

if !res.Required {
return &client.MFAAuthenticateChallenge{}, nil
}
// If the MFA requirement check is being performed for a leaf host, we must check directly
// with the leaf cluster before the authentication challenge request through root.
if req.IsMFARequiredRequest.ClusterID != "" && req.IsMFARequiredRequest.ClusterID != c.cfg.RootClusterName {
site, err := h.getSiteByClusterName(ctx, c, req.IsMFARequiredRequest.ClusterID)
if err != nil {
return nil, trace.Wrap(err)
}
} else {
mfaRequiredCheck = mfaRequiredCheckProto

clusterClient, err := c.GetUserClient(ctx, site)
if err != nil {
return false, trace.Wrap(err)
}

res, err := clusterClient.IsMFARequired(ctx, mfaRequiredCheckProto)
if err != nil {
return false, trace.Wrap(err)
}

if !res.Required {
return &client.MFAAuthenticateChallenge{}, nil
}

// We don't want to check again through the root cluster below.
mfaRequiredCheckProto = nil
}
}

Expand Down Expand Up @@ -254,7 +251,7 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
},
MFARequiredCheck: mfaRequiredCheck,
MFARequiredCheck: mfaRequiredCheckProto,
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope(req.ChallengeScope),
AllowReuse: allowReuse,
Expand Down Expand Up @@ -459,6 +456,9 @@ type isMFARequiredApp struct {
type isMFARequiredAdminAction struct{}

type isMFARequiredRequest struct {
// ClusterID is the ID of the cluster to check against for MFA requirements.
// If not set, MFA requirements will be checked against the root cluster.
ClusterID string `json:"clusterId,omitempty"`
// Database contains fields required to check if target database
// requires MFA check.
Database *isMFARequiredDatabase `json:"database,omitempty"`
Expand Down
1 change: 1 addition & 0 deletions web/packages/teleport/src/AppLauncher/AppLauncher.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export function AppLauncher() {
req: {
scope: MfaChallengeScope.USER_SESSION,
isMfaRequiredRequest: {
clusterId: pathParams.clusterId,
app: {
fqdn: pathParams.fqdn,
cluster_name: pathParams.clusterId,
Expand Down
3 changes: 1 addition & 2 deletions web/packages/teleport/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,7 @@ const cfg = {
return generatePath(cfg.api.connectionDiagnostic, { clusterId });
},

getMfaRequiredUrl() {
const clusterId = cfg.proxyCluster;
getMfaRequiredUrl(clusterId: string = cfg.proxyCluster) {
return generatePath(cfg.api.mfaRequired, { clusterId });
},

Expand Down
11 changes: 8 additions & 3 deletions web/packages/teleport/src/services/auth/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ function checkMfaRequired(
params: IsMfaRequiredRequest,
abortSignal?
): Promise<IsMfaRequiredResponse> {
return api.post(cfg.getMfaRequiredUrl(), params, abortSignal);
return api.post(cfg.getMfaRequiredUrl(params.clusterId), params, abortSignal);
}

function base64EncodeUnicode(str: string) {
Expand Down Expand Up @@ -508,13 +508,18 @@ function waitForMessage(

export default auth;

export type IsMfaRequiredRequest =
export type IsMfaRequiredRequest = {
// clusterId is the cluster to check mfa requirement against. When connecting to
// leaf hosts, this should be set to the leaf clusterId.
clusterId?: string;
} & (
| IsMfaRequiredDatabase
| IsMfaRequiredNode
| IsMfaRequiredKube
| IsMfaRequiredWindowsDesktop
| IsMfaRequiredApp
| IsMfaRequiredAdminAction;
| IsMfaRequiredAdminAction
);

export type IsMfaRequiredResponse = {
required: boolean;
Expand Down

0 comments on commit add9cc9

Please sign in to comment.