Skip to content

Commit

Permalink
Fallback to vm validation
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardjkim committed Dec 14, 2024
1 parent f48d229 commit 32df16f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 51 deletions.
111 changes: 90 additions & 21 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/coreos/go-oidc"
"github.com/digitorus/pkcs7"
Expand Down Expand Up @@ -166,44 +167,45 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error {

type azureRegisterOption func(cfg *azureRegisterConfig)

// verifyAttestedData verifies that an attested data document was signed
// by Azure.
func verifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) error {
// parseAndVeryAttestedData verifies that an attested data document was signed
// by Azure. If verification is successful, it returns the ID of the VM that
// produced the document.
func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) {
var signedAD signedAttestedData
if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil {
return trace.Wrap(err)
return "", "", trace.Wrap(err)
}
if signedAD.Encoding != "pkcs7" {
return trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding)
return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding)
}

sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----"
sigBER, _ := pem.Decode([]byte(sigPEM))
if sigBER == nil {
return trace.AccessDenied("unable to decode attested data document")
return "", "", trace.AccessDenied("unable to decode attested data document")
}

p7, err := pkcs7.Parse(sigBER.Bytes)
if err != nil {
return trace.Wrap(err)
return "", "", trace.Wrap(err)
}
var ad attestedData
if err := utils.FastUnmarshal(p7.Content, &ad); err != nil {
return trace.Wrap(err)
return "", "", trace.Wrap(err)
}
if ad.Nonce != challenge {
return trace.AccessDenied("challenge is missing or does not match")
return "", "", trace.AccessDenied("challenge is missing or does not match")
}

if len(p7.Certificates) == 0 {
return trace.AccessDenied("no certificates for signature")
return "", "", trace.AccessDenied("no certificates for signature")
}
fixAzureSigningAlgorithm(p7)

// Azure only sends the leaf cert, so we have to fetch the intermediate.
intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0])
if err != nil {
return trace.Wrap(err)
return "", "", trace.Wrap(err)
}
if intermediate != nil {
p7.Certificates = append(p7.Certificates, intermediate)
Expand All @@ -215,10 +217,10 @@ func verifyAttestedData(ctx context.Context, adBytes []byte, challenge string, c
}

if err := p7.VerifyWithChain(pool); err != nil {
return trace.Wrap(err)
return "", "", trace.Wrap(err)
}

return nil
return ad.SubscriptionID, ad.ID, nil
}

// verifyToken verifies the token and validates the expected claims.
Expand Down Expand Up @@ -253,7 +255,52 @@ func verifyToken(ctx context.Context, cfg *azureRegisterConfig, accessToken stri
return tokenClaims, nil
}

func checkAzureAllowRules(claims *accessTokenClaims, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
// verifyVMIdentity verifies that the provided access token came from the
// correct Azure VM.
func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, tokenClaims *accessTokenClaims, accessToken, subscriptionID, vmID string) (*azure.VirtualMachine, error) {
tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
Token: accessToken,
ExpiresOn: tokenClaims.Expiry.Time(),
})
vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential)
if err != nil {
return nil, trace.Wrap(err)
}

resourceID, err := arm.ParseResourceID(tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}

var vm *azure.VirtualMachine

// If the token is from the system-assigned managed identity, the resource ID
// is for the VM itself and we can use it to look up the VM.
if slices.Contains(resourceID.ResourceType.Types, "virtualMachines") {
vm, err = vmClient.Get(ctx, tokenClaims.ManangedIdentityResourceID)
if err != nil {
return nil, trace.Wrap(err)
}
if vm.VMID != vmID {
return nil, trace.AccessDenied("vm ID does not match")
}

// If the token is from a user-assigned managed identity, the resource ID is
// for the identity and we need to look the VM up by VM ID.
} else {
vm, err = vmClient.GetByVMID(ctx, vmID)
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.AccessDenied("no VM found with matching VM ID")
}
return nil, trace.Wrap(err)
}
}

return vm, nil
}

func checkAzureAllowRulesWithClaims(claims *accessTokenClaims, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
rid := claims.AzureResourceID
if rid == "" {
// xms_az_rid claim is omitted when the VM is assigned a System-Assigned Identity.
Expand All @@ -270,17 +317,32 @@ func checkAzureAllowRules(claims *accessTokenClaims, token string, allowRules []
return trace.BadParameter("unexpected resource type: %q", resourceID.ResourceType.Type)
}

if err := checkAzureAllowRules(resourceID.SubscriptionID, resourceID.ResourceGroupName, allowRules); err != nil {
return trace.AccessDenied("instance %v did not match any allow rules in token %v", resourceID.Name, token)
}
return nil
}

func checkAzureAllowRulesWithVMs(vm *azure.VirtualMachine, token string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
if err := checkAzureAllowRules(vm.Subscription, vm.ResourceGroup, allowRules); err != nil {
return trace.AccessDenied("instance %v did not match any allow rules in token %v", vm.Name, token)
}
return nil
}

func checkAzureAllowRules(subscription, resourceGroup string, allowRules []*types.ProvisionTokenSpecV2Azure_Rule) error {
for _, rule := range allowRules {
if rule.Subscription != resourceID.SubscriptionID {
if rule.Subscription != subscription {
continue
}
if !azureResourceGroupIsAllowed(rule.ResourceGroups, resourceID.ResourceGroupName) {
if !azureResourceGroupIsAllowed(rule.ResourceGroups, resourceGroup) {
continue
}
return nil
}
return trace.AccessDenied("instance %v did not match any allow rules in token %v", resourceID.Name, token)
return trace.AccessDenied("matching allow rule not found")
}

func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool {
if len(allowedResourceGroups) == 0 {
return true
Expand Down Expand Up @@ -311,7 +373,7 @@ func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *p
return trace.AccessDenied("this token does not support the Azure join method")
}

err = verifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -326,11 +388,18 @@ func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *p
return trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken)
}

if err := checkAzureAllowRules(claims, token.GetName(), token.Spec.Azure.Allow); err != nil {
return trace.Wrap(err)
if err := checkAzureAllowRulesWithClaims(claims, token.GetName(), token.Spec.Azure.Allow); err == nil {
return nil
}

return nil
// Required claims for validation are only present for source resource types
// that have onboarded to SNI auth. Fallback to validation with VMs if
// unable to validate with claims.
vm, err := verifyVMIdentity(ctx, cfg, claims, req.AccessToken, subID, vmID)
if err != nil {
return trace.Wrap(err)
}
return trace.Wrap(checkAzureAllowRulesWithVMs(vm, token.GetName(), token.Spec.Azure.Allow))
}

func generateAzureChallenge() (string, error) {
Expand Down
59 changes: 29 additions & 30 deletions lib/auth/join_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,16 +386,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
{
name: "system-managed identity ok",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenSubscription: "system-managed-test",
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName),
tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Allow: []*types.ProvisionTokenSpecV2Azure_Rule{
{
Subscription: defaultSubscription,
ResourceGroups: []string{defaultResourceGroup},
Subscription: "system-managed-test",
ResourceGroups: []string{"system-managed-test"},
},
},
},
Expand All @@ -408,37 +408,37 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
{
name: "system-managed identity with wrong subscription",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenSubscription: "system-managed-test",
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: vmResourceID("alternate-subscription-id", defaultResourceGroup, defaultVMName),
tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Allow: []*types.ProvisionTokenSpecV2Azure_Rule{
{
Subscription: defaultSubscription,
ResourceGroups: []string{defaultResourceGroup},
ResourceGroups: []string{"system-managed-test"},
},
},
},
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
certs: []*x509.Certificate{tlsConfig.Certificate},
assertError: isAccessDenied,
assertError: require.Error,
},
{
name: "system-managed identity with wrong resource group",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenSubscription: "system-managed-test",
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultVMName),
tokenManagedIdentityResourceID: vmResourceID("system-managed-test", "system-managed-test", defaultVMName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Allow: []*types.ProvisionTokenSpecV2Azure_Rule{
{
Subscription: defaultSubscription,
Subscription: "system-managed-test",
ResourceGroups: []string{defaultResourceGroup},
},
},
Expand All @@ -447,22 +447,22 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
},
verify: mockVerifyToken(nil),
certs: []*x509.Certificate{tlsConfig.Certificate},
assertError: isAccessDenied,
assertError: require.Error,
},
{
name: "user-managed identity ok",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenSubscription: "user-managed-test",
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName),
tokenAzureResourceID: vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName),
tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName),
tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Allow: []*types.ProvisionTokenSpecV2Azure_Rule{
{
Subscription: defaultSubscription,
ResourceGroups: []string{defaultResourceGroup},
Subscription: "user-managed-test",
ResourceGroups: []string{"user-managed-test"},
},
},
},
Expand All @@ -475,39 +475,39 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
{
name: "user-managed identity with wrong subscription",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenSubscription: "user-managed-test",
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName),
tokenAzureResourceID: vmResourceID("alternate-subscription-id", defaultResourceGroup, defaultVMName),
tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName),
tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Allow: []*types.ProvisionTokenSpecV2Azure_Rule{
{
Subscription: defaultSubscription,
ResourceGroups: []string{defaultResourceGroup},
ResourceGroups: []string{"user-managed-test"},
},
},
},
JoinMethod: types.JoinMethodAzure,
},
verify: mockVerifyToken(nil),
certs: []*x509.Certificate{tlsConfig.Certificate},
assertError: isAccessDenied,
assertError: require.Error,
},
{
name: "user-managed identity with wrong resource group",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenSubscription: "user-managed-test",
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName),
tokenAzureResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultVMName),
tokenManagedIdentityResourceID: identityResourceID("user-managed-test", "user-managed-test", defaultIdentityName),
tokenAzureResourceID: vmResourceID("user-managed-test", "user-managed-test", defaultVMName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Allow: []*types.ProvisionTokenSpecV2Azure_Rule{
{
Subscription: defaultSubscription,
Subscription: "user-managed-test",
ResourceGroups: []string{defaultResourceGroup},
},
},
Expand All @@ -516,15 +516,14 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
},
verify: mockVerifyToken(nil),
certs: []*x509.Certificate{tlsConfig.Certificate},
assertError: isAccessDenied,
assertError: require.Error,
},
{
name: "invalid resource type",
name: "get vm from identity",
requestTokenName: "test-token",
tokenSubscription: defaultSubscription,
tokenVMID: defaultVMID,
tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName),
tokenAzureResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName),
tokenSpec: types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Azure: &types.ProvisionTokenSpecV2Azure{
Expand All @@ -539,7 +538,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
},
verify: mockVerifyToken(nil),
certs: []*x509.Certificate{tlsConfig.Certificate},
assertError: isBadParameter,
assertError: require.NoError,
},
}

Expand Down

0 comments on commit 32df16f

Please sign in to comment.