Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] Fix Azure join method throttling #50929

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
rsID := vmResourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, a.clock.Now())
accessToken, err := makeToken(rsID, "", a.clock.Now())
require.NoError(t, err)

// add token to auth server
Expand Down
136 changes: 108 additions & 28 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package auth

import (
"cmp"
"context"
"crypto/x509"
"encoding/base64"
Expand All @@ -30,6 +31,8 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/coreos/go-oidc"
"github.com/digitorus/pkcs7"
"github.com/go-jose/go-jose/v3/jwt"
Expand All @@ -38,12 +41,20 @@ import (

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/utils"
)

const azureAccessTokenAudience = "https://management.azure.com/"
const (
azureAccessTokenAudience = "https://management.azure.com/"

// azureUserAgent specifies the Azure User-Agent identification for telemetry.
azureUserAgent = "teleport"
// azureVirtualMachine specifies the Azure virtual machine resource type.
azureVirtualMachine = "virtualMachines"
)

// Structs for unmarshaling attested data. Schema can be found at
// https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2
Expand Down Expand Up @@ -76,9 +87,23 @@ type attestedData struct {

type accessTokenClaims struct {
jwt.Claims
ResourceID string `json:"xms_mirid"`
TenantID string `json:"tid"`
Version string `json:"ver"`
TenantID string `json:"tid"`
Version string `json:"ver"`

// Azure JWT tokens include two optional claims that can be used to validate
// the subscription and resource group of a joining node. These claims hold
// different values depending on the assigned Managed Identity of the Azure VM:
// - xms_mirid:
// - For System-Assigned Identity it represents the resource id of the VM.
// - For User-Assigned Identity it represents the resource id of the user-assigned identity.
// - xms_az_rid:
// - For System-Assigned Identity this claim is omitted.
// - For User-Assigned Identity it represents the resource id of the VM.
//
// More details at: https://learn.microsoft.com/en-us/answers/questions/1282788/existence-of-xms-az-rid-field-in-activity-logs-of

ManangedIdentityResourceID string `json:"xms_mirid"`
AzureResourceID string `json:"xms_az_rid"`
}

type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error)
Expand Down Expand Up @@ -144,7 +169,16 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {
}
if cfg.getVMClient == nil {
cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil)
// The User-Agent is added for debugging purposes. It helps identify
// and isolate teleport traffic.
opts := &armpolicy.ClientOptions{
ClientOptions: policy.ClientOptions{
Telemetry: policy.TelemetryOptions{
ApplicationID: azureUserAgent,
},
},
}
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, opts)
return client, trace.Wrap(err)
}
}
Expand Down Expand Up @@ -210,8 +244,15 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s
}

// verifyVMIdentity verifies that the provided access token came from the
// correct Azure VM.
func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, subscriptionID, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) {
// correct Azure VM. Returns the Azure join attributes.
func verifyVMIdentity(
ctx context.Context,
cfg *azureRegisterConfig,
accessToken,
subscriptionID,
vmID string,
requestStart time.Time,
) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) {
tokenClaims, err := cfg.verify(ctx, accessToken)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -239,6 +280,19 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

// Listing all VMs in an Azure subscription during the verification process
// is problematic when the subscription contains a large number of VMs.
// In some cases this can lead to throttling due to Azure API rate limits.
// To address the issue, the verification process first attempts to
// parse the required VM identifiers from the token claims. If that fails,
// it falls back to the original method of listing VMs and parsing the VM
// identifiers from the VM resource.
vmSubscription, vmResourceGroup, err := claimsToIdentifiers(tokenClaims)
if err == nil {
return azureJoinToAttrs(vmSubscription, vmResourceGroup), nil
}
log.WithError(err).Warn("Failed to parse VM identifiers from claims. Retrying with Azure VM API.")

tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
Token: accessToken,
ExpiresOn: tokenClaims.Expiry.Time(),
Expand All @@ -248,7 +302,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

resourceID, err := arm.ParseResourceID(tokenClaims.ResourceID)
resourceID, err := arm.ParseResourceID(tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -257,8 +311,8 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken

// If the token is from the system-assigned managed identity, the resource ID
// is for the VM itself and we can use it to look up the VM.
if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") {
vm, err = vmClient.Get(ctx, tokenClaims.ResourceID)
if slices.Contains(resourceID.ResourceType.Types, azureVirtualMachine) {
vm, err = vmClient.Get(ctx, tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -277,21 +331,35 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}
}
return azureJoinToAttrs(vm.Subscription, vm.ResourceGroup), nil
}

return vm, nil
// claimsToIdentifiers extracts the subscription ID and resource group name
// of the joining VM from the resource ID carried in the access token claims.
//
// The xms_az_rid claim is preferred. When it is empty, the xms_mirid claim
// is used in its place. An error is returned if the selected resource ID
// cannot be parsed or does not refer to a virtual machine resource.
func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) {
	// Pick the first non-empty resource ID, giving precedence to xms_az_rid.
	rawID := cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID)

	parsed, parseErr := arm.ParseResourceID(rawID)
	if parseErr != nil {
		return "", "", trace.Wrap(parseErr, "failed to parse resource id from claims")
	}

	// Only a virtual machine resource ID identifies the joining VM.
	if isVM := slices.Contains(parsed.ResourceType.Types, azureVirtualMachine); isVM {
		return parsed.SubscriptionID, parsed.ResourceGroupName, nil
	}
	return "", "", trace.BadParameter("unexpected resource type: %q", parsed.ResourceType.Type)
}

func checkAzureAllowRules(vm *azure.VirtualMachine, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
for _, rule := range allowRules {
if rule.Subscription != vm.Subscription {
func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzure, token *types.ProvisionTokenV2) error {
for _, rule := range token.Spec.Azure.Allow {
if rule.Subscription != attrs.Subscription {
continue
}
if !azureResourceGroupIsAllowed(rule.ResourceGroups, vm.ResourceGroup) {
if !azureResourceGroupIsAllowed(rule.ResourceGroups, attrs.ResourceGroup) {
continue
}
return nil
}
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vm.Name, token)
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName())
}
func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool {
if len(allowedResourceGroups) == 0 {
Expand All @@ -312,37 +380,48 @@ func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup
return false
}

func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *proto.RegisterUsingAzureMethodRequest, cfg *azureRegisterConfig) error {
// azureJoinToAttrs builds the Azure join attributes from the given
// subscription ID and resource group.
func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv1pb.JoinAttrsAzure {
	attrs := new(workloadidentityv1pb.JoinAttrsAzure)
	attrs.Subscription = subscriptionID
	attrs.ResourceGroup = resourceGroupID
	return attrs
}

func (a *Server) checkAzureRequest(
ctx context.Context,
challenge string,
req *proto.RegisterUsingAzureMethodRequest,
cfg *azureRegisterConfig,
) (*workloadidentityv1pb.JoinAttrsAzure, error) {
requestStart := a.clock.Now()
tokenName := req.RegisterUsingTokenRequest.Token
provisionToken, err := a.GetToken(ctx, tokenName)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
if provisionToken.GetJoinMethod() != types.JoinMethodAzure {
return trace.AccessDenied("this token does not support the Azure join method")
return nil, trace.AccessDenied("this token does not support the Azure join method")
}

subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart)
attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart)
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}

token, ok := provisionToken.(*types.ProvisionTokenV2)
if !ok {
return trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken)
return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken)
}

if err := checkAzureAllowRules(vm, token.GetName(), token.Spec.Azure.Allow); err != nil {
return trace.Wrap(err)
if err := checkAzureAllowRules(vmID, attrs, token); err != nil {
return attrs, trace.Wrap(err)
}

return nil
return attrs, nil
}

func generateAzureChallenge() (string, error) {
Expand Down Expand Up @@ -399,7 +478,8 @@ func (a *Server) RegisterUsingAzureMethod(
return nil, trace.Wrap(err)
}

if err := a.checkAzureRequest(ctx, challenge, req, cfg); err != nil {
_, err = a.checkAzureRequest(ctx, challenge, req, cfg)
if err != nil {
return nil, trace.Wrap(err)
}

Expand Down
Loading
Loading