From 32df16f4daa51ffc3e8d4242df47e0a46dd55c92 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 13 Dec 2024 17:58:27 -0800 Subject: [PATCH] Fallback to vm validation --- lib/auth/join_azure.go | 111 +++++++++++++++++++++++++++++------- lib/auth/join_azure_test.go | 59 ++++++++++--------- 2 files changed, 119 insertions(+), 51 deletions(-) diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 3064df724c567..1c923bfb452a7 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -28,6 +28,7 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/coreos/go-oidc" "github.com/digitorus/pkcs7" @@ -166,44 +167,45 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { type azureRegisterOption func(cfg *azureRegisterConfig) -// verifyAttestedData verifies that an attested data document was signed -// by Azure. -func verifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) error { +// parseAndVeryAttestedData verifies that an attested data document was signed +// by Azure. If verification is successful, it returns the ID of the VM that +// produced the document. +func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) { var signedAD signedAttestedData if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { - return trace.Wrap(err) + return "", "", trace.Wrap(err) } if signedAD.Encoding != "pkcs7" { - return trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) + return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) } sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----" sigBER, _ := pem.Decode([]byte(sigPEM)) if sigBER == nil { - return trace.AccessDenied("unable to decode attested data document") + return "", "", trace.AccessDenied("unable to decode attested data document") } p7, err := pkcs7.Parse(sigBER.Bytes) if err != nil { - return trace.Wrap(err) + return "", "", trace.Wrap(err) } var ad attestedData if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { - return trace.Wrap(err) + return "", "", trace.Wrap(err) } if ad.Nonce != challenge { - return trace.AccessDenied("challenge is missing or does not match") + return "", "", trace.AccessDenied("challenge is missing or does not match") } if len(p7.Certificates) == 0 { - return trace.AccessDenied("no certificates for signature") + return "", "", trace.AccessDenied("no certificates for signature") } fixAzureSigningAlgorithm(p7) // Azure only sends the leaf cert, so we have to fetch the intermediate. intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0]) if err != nil { - return trace.Wrap(err) + return "", "", trace.Wrap(err) } if intermediate != nil { p7.Certificates = append(p7.Certificates, intermediate) @@ -215,10 +217,10 @@ func verifyAttestedData(ctx context.Context, adBytes []byte, challenge string, c } if err := p7.VerifyWithChain(pool); err != nil { - return trace.Wrap(err) + return "", "", trace.Wrap(err) } - return nil + return ad.SubscriptionID, ad.ID, nil } // verifyToken verifies the token and validates the expected claims. @@ -253,7 +255,52 @@ func verifyToken(ctx context.Context, cfg *azureRegisterConfig, accessToken stri return tokenClaims, nil } -func checkAzureAllowRules(claims *accessTokenClaims, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error { +// verifyVMIdentity verifies that the provided access token came from the +// correct Azure VM. +func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, tokenClaims *accessTokenClaims, accessToken, subscriptionID, vmID string) (*azure.VirtualMachine, error) { + tokenCredential := azure.NewStaticCredential(azcore.AccessToken{ + Token: accessToken, + ExpiresOn: tokenClaims.Expiry.Time(), + }) + vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential) + if err != nil { + return nil, trace.Wrap(err) + } + + resourceID, err := arm.ParseResourceID(tokenClaims.ManangedIdentityResourceID) + if err != nil { + return nil, trace.Wrap(err) + } + + var vm *azure.VirtualMachine + + // If the token is from the system-assigned managed identity, the resource ID + // is for the VM itself and we can use it to look up the VM. + if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") { + vm, err = vmClient.Get(ctx, tokenClaims.ManangedIdentityResourceID) + if err != nil { + return nil, trace.Wrap(err) + } + if vm.VMID != vmID { + return nil, trace.AccessDenied("vm ID does not match") + } + + // If the token is from a user-assigned managed identity, the resource ID is + // for the identity and we need to look the VM up by VM ID. + } else { + vm, err = vmClient.GetByVMID(ctx, vmID) + if err != nil { + if trace.IsNotFound(err) { + return nil, trace.AccessDenied("no VM found with matching VM ID") + } + return nil, trace.Wrap(err) + } + } + + return vm, nil +} + +func checkAzureAllowRulesWithClaims(claims *accessTokenClaims, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error { rid := claims.AzureResourceID if rid == "" { // xms_az_rid claim is omitted when the VM is assigned a System-Assigned Identity. @@ -270,17 +317,32 @@ func checkAzureAllowRules(claims *accessTokenClaims, token string, allowRules [] return trace.BadParameter("unexpected resource type: %q", resourceID.ResourceType.Type) } + if err := checkAzureAllowRules(resourceID.SubscriptionID, resourceID.ResourceGroupName, allowRules); err != nil { + return trace.AccessDenied("instance %v did not match any allow rules in token %v", resourceID.Name, token) + } + return nil +} + +func checkAzureAllowRulesWithVMs(vm *azure.VirtualMachine, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error { + if err := checkAzureAllowRules(vm.Subscription, vm.ResourceGroup, allowRules); err != nil { + return trace.AccessDenied("instance %v did not match any allow rules in token %v", vm.Name, token) + } + return nil +} + +func checkAzureAllowRules(subscription, resourceGroup string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error { for _, rule := range allowRules { - if rule.Subscription != resourceID.SubscriptionID { + if rule.Subscription != subscription { continue } - if !azureResourceGroupIsAllowed(rule.ResourceGroups, resourceID.ResourceGroupName) { + if !azureResourceGroupIsAllowed(rule.ResourceGroups, resourceGroup) { continue } return nil } - return trace.AccessDenied("instance %v did not match any allow rules in token %v", resourceID.Name, token) + return trace.AccessDenied("matching allow rule not found") } + func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool { if len(allowedResourceGroups) == 0 { return true @@ -311,7 +373,7 @@ func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *p return trace.AccessDenied("this token does not support the Azure join method") } - err = verifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) + subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) if err != nil { return trace.Wrap(err) } @@ -326,11 +388,18 @@ func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *p return trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) } - if err := checkAzureAllowRules(claims, token.GetName(), token.Spec.Azure.Allow); err != nil { - return trace.Wrap(err) + if err := checkAzureAllowRulesWithClaims(claims, token.GetName(), token.Spec.Azure.Allow); err == nil { + return nil } - return nil + // Required claims for validation are only present for source resource types + // that have onboarded to SNI auth. Fallback to validation with VMs if + // unable to validate with claims. + vm, err := verifyVMIdentity(ctx, cfg, claims, req.AccessToken, subID, vmID) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(checkAzureAllowRulesWithVMs(vm, token.GetName(), token.Spec.Azure.Allow)) } func generateAzureChallenge() (string, error) { diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 3ef911f7707aa..38be466401a02 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -386,16 +386,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "system-managed identity ok", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, + tokenSubscription: "system-managed-test", tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName), + tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: defaultSubscription, - ResourceGroups: []string{defaultResourceGroup}, + Subscription: "system-managed-test", + ResourceGroups: []string{"system-managed-test"}, }, }, }, @@ -408,16 +408,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "system-managed identity with wrong subscription", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, + tokenSubscription: "system-managed-test", tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: vmResourceID("alternate-subscription-id", defaultResourceGroup, defaultVMName), + tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { Subscription: defaultSubscription, - ResourceGroups: []string{defaultResourceGroup}, + ResourceGroups: []string{"system-managed-test"}, }, }, }, @@ -425,20 +425,20 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, + assertError: require.Error, }, { name: "system-managed identity with wrong resource group", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, + tokenSubscription: "system-managed-test", tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultVMName), + tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: defaultSubscription, + Subscription: "system-managed-test", ResourceGroups: []string{defaultResourceGroup}, }, }, @@ -447,22 +447,22 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, + assertError: require.Error, }, { name: "user-managed identity ok", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, + tokenSubscription: "user-managed-test", tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), - tokenAzureResourceID: vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName), + tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: defaultSubscription, - ResourceGroups: []string{defaultResourceGroup}, + Subscription: "user-managed-test", + ResourceGroups: []string{"user-managed-test"}, }, }, }, @@ -475,17 +475,17 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "user-managed identity with wrong subscription", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, + tokenSubscription: "user-managed-test", tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), - tokenAzureResourceID: vmResourceID("alternate-subscription-id", defaultResourceGroup, defaultVMName), + tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { Subscription: defaultSubscription, - ResourceGroups: []string{defaultResourceGroup}, + ResourceGroups: []string{"user-managed-test"}, }, }, }, @@ -493,21 +493,21 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, + assertError: require.Error, }, { name: "user-managed identity with wrong resource group", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, + tokenSubscription: "user-managed-test", tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), - tokenAzureResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultVMName), + tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName), + tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: defaultSubscription, + Subscription: "user-managed-test", ResourceGroups: []string{defaultResourceGroup}, }, }, @@ -516,15 +516,14 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, + assertError: require.Error, }, { - name: "invalid resource type", + name: "get vm from identity", requestTokenName: "test-token", tokenSubscription: defaultSubscription, tokenVMID: defaultVMID, tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), - tokenAzureResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -539,7 +538,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isBadParameter, + assertError: require.NoError, }, }