Skip to content

Commit

Permalink
Validate Azure join using JWT claims
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardjkim committed Jan 9, 2025
1 parent 5eee08d commit d5c1064
Show file tree
Hide file tree
Showing 3 changed files with 420 additions and 73 deletions.
2 changes: 1 addition & 1 deletion lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
rsID := vmResourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, a.clock.Now())
accessToken, err := makeToken(rsID, "", a.clock.Now())
require.NoError(t, err)

// add token to auth server
Expand Down
112 changes: 89 additions & 23 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@
package auth

import (
"cmp"
"context"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"log/slog"
"net/url"
"slices"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/coreos/go-oidc"
"github.com/digitorus/pkcs7"
"github.com/go-jose/go-jose/v3/jwt"
Expand All @@ -44,7 +48,14 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

const azureAccessTokenAudience = "https://management.azure.com/"
const (
azureAccessTokenAudience = "https://management.azure.com/"

// azureUserAgent specifies the Azure User-Agent identification for telemetry.
azureUserAgent = "teleport"
// azureVirtualMachine specifies the Azure virtual machine resource type.
azureVirtualMachine = "virtualMachines"
)

// Structs for unmarshaling attested data. Schema can be found at
// https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2
Expand Down Expand Up @@ -77,9 +88,23 @@ type attestedData struct {

type accessTokenClaims struct {
jwt.Claims
ResourceID string `json:"xms_mirid"`
TenantID string `json:"tid"`
Version string `json:"ver"`
TenantID string `json:"tid"`
Version string `json:"ver"`

// Azure JWT tokens include two optional claims that can be used to validate
// the subscription and resource group of a joining node. These claims hold
// different values depending on the assigned Managed Identity of the Azure VM:
// - xms_mirid:
// - For System-Assigned Identity it represents the resource id of the VM.
// - For User-Assigned Identity it represents the resource id of the user-assigned identity.
// - xms_az_rid:
// - For System-Assigned Identity this claim is omitted.
// - For User-Assigned Identity it represents the resource id of the VM.
//
// More details at: https://learn.microsoft.com/en-us/answers/questions/1282788/existence-of-xms-az-rid-field-in-activity-logs-of

ManangedIdentityResourceID string `json:"xms_mirid"`
AzureResourceID string `json:"xms_az_rid"`
}

type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error)
Expand Down Expand Up @@ -145,7 +170,14 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {
}
if cfg.getVMClient == nil {
cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) {
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil)
opts := &armpolicy.ClientOptions{
ClientOptions: policy.ClientOptions{
Telemetry: policy.TelemetryOptions{
ApplicationID: azureUserAgent,
},
},
}
client, err := azure.NewVirtualMachinesClient(subscriptionID, token, opts)
return client, trace.Wrap(err)
}
}
Expand Down Expand Up @@ -211,8 +243,16 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s
}

// verifyVMIdentity verifies that the provided access token came from the
// correct Azure VM.
func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, subscriptionID, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) {
// correct Azure VM. Returns the Azure join attributes.
func verifyVMIdentity(
ctx context.Context,
cfg *azureRegisterConfig,
accessToken,
subscriptionID,
vmID string,
requestStart time.Time,
logger *slog.Logger,
) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) {
tokenClaims, err := cfg.verify(ctx, accessToken)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -240,6 +280,20 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

// Listing all VMs in an Azure subscription during the verification process
// is problematic when the subscription contains a large number of VMs.
// In some cases this can lead to throttling due to Azure API rate limits.
// To address the issue, the verification process first attempts to
// parse the required VM identifiers from the token claims. If that fails,
// it falls back to the original method of listing VMs and parsing the VM
// identifiers from the VM resource.
vmSubscription, vmResourceGroup, err := claimsToIdentifiers(tokenClaims)
if err == nil {
return azureJoinToAttrs(vmSubscription, vmResourceGroup), nil
}
logger.WarnContext(ctx, "Failed to parse VM identifiers from claims. Retrying with Azure VM API.",
"error", err)

tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
Token: accessToken,
ExpiresOn: tokenClaims.Expiry.Time(),
Expand All @@ -249,7 +303,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}

resourceID, err := arm.ParseResourceID(tokenClaims.ResourceID)
resourceID, err := arm.ParseResourceID(tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -258,8 +312,8 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken

// If the token is from the system-assigned managed identity, the resource ID
// is for the VM itself and we can use it to look up the VM.
if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") {
vm, err = vmClient.Get(ctx, tokenClaims.ResourceID)
if slices.Contains(resourceID.ResourceType.Types, azureVirtualMachine) {
vm, err = vmClient.Get(ctx, tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -278,21 +332,35 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
return nil, trace.Wrap(err)
}
}
return azureJoinToAttrs(vm.Subscription, vm.ResourceGroup), nil
}

return vm, nil
// claimsToIdentifiers extracts the subscription ID and resource group name of
// the joining VM from the resource ID carried in the access token claims.
//
// The xms_az_rid claim is used when it is set; otherwise the xms_mirid claim
// is used. An error is returned when the selected resource ID cannot be parsed
// or when it does not refer to a virtual machine resource.
func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) {
	// xms_az_rid is omitted for a System-Assigned Identity, in which case
	// xms_mirid holds the resource ID of the VM itself.
	rawID := cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID)

	parsed, parseErr := arm.ParseResourceID(rawID)
	if parseErr != nil {
		return "", "", trace.Wrap(parseErr, "failed to parse resource id from claims")
	}

	// Only resource IDs that point at a virtual machine are accepted.
	if slices.Contains(parsed.ResourceType.Types, azureVirtualMachine) {
		return parsed.SubscriptionID, parsed.ResourceGroupName, nil
	}
	return "", "", trace.BadParameter("unexpected resource type: %q", parsed.ResourceType.Type)
}

func checkAzureAllowRules(vm *azure.VirtualMachine, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
for _, rule := range allowRules {
if rule.Subscription != vm.Subscription {
func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzure, token *types.ProvisionTokenV2) error {
for _, rule := range token.Spec.Azure.Allow {
if rule.Subscription != attrs.Subscription {
continue
}
if !azureResourceGroupIsAllowed(rule.ResourceGroups, vm.ResourceGroup) {
if !azureResourceGroupIsAllowed(rule.ResourceGroups, attrs.ResourceGroup) {
continue
}
return nil
}
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vm.Name, token)
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName())
}
func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool {
if len(allowedResourceGroups) == 0 {
Expand All @@ -313,10 +381,10 @@ func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup
return false
}

func azureJoinToAttrs(vm *azure.VirtualMachine) *workloadidentityv1pb.JoinAttrsAzure {
func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv1pb.JoinAttrsAzure {
return &workloadidentityv1pb.JoinAttrsAzure{
Subscription: vm.Subscription,
ResourceGroup: vm.ResourceGroup,
Subscription: subscriptionID,
ResourceGroup: resourceGroupID,
}
}

Expand Down Expand Up @@ -345,13 +413,11 @@ func (a *Server) checkAzureRequest(
return nil, trace.Wrap(err)
}

vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart)
attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart, a.logger)
if err != nil {
return nil, trace.Wrap(err)
}
attrs := azureJoinToAttrs(vm)

if err := checkAzureAllowRules(vm, token.GetName(), token.Spec.Azure.Allow); err != nil {
if err := checkAzureAllowRules(vmID, attrs, token); err != nil {
return attrs, trace.Wrap(err)
}

Expand Down
Loading

0 comments on commit d5c1064

Please sign in to comment.