From 5ee991132405befba6759b593c439df23a510f52 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 13 Dec 2024 13:48:13 -0800 Subject: [PATCH] Remove Azure client dependency from registerConfig --- lib/auth/bot_test.go | 18 +-- lib/auth/join_azure.go | 7 - lib/auth/join_azure_test.go | 210 ++++++---------------------- lib/cloud/azure/vm.go | 21 --- lib/srv/discovery/discovery_test.go | 4 - 5 files changed, 45 insertions(+), 215 deletions(-) diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index b71cadff189d0..c4239ee2d52eb 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -58,7 +58,6 @@ import ( "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/cloud/azure" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/fixtures" @@ -586,21 +585,6 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { require.NoError(t, err) require.NoError(t, a.UpsertToken(ctx, azureToken)) - vmClient := &mockAzureVMClient{ - vms: map[string]*azure.VirtualMachine{ - rsID: { - ID: rsID, - Name: "test-vm", - Subscription: subID, - ResourceGroup: resourceGroup, - VMID: vmID, - }, - }, - } - getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ - subID: vmClient, - }) - tlsConfig, err := fixtures.LocalTLSConfig() require.NoError(t, err) @@ -641,7 +625,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { AccessToken: accessToken, } return req, nil - }, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClientGetter(getVMClient)) + }, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil))) require.NoError(t, err) checkCertLoginIP(t, certs.TLS, remoteAddr) }) diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 3064df724c567..c2a96372f611c 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -102,7 +102,6 @@ type azureRegisterConfig struct { clock clockwork.Clock certificateAuthorities []*x509.Certificate verify azureVerifyTokenFunc - getVMClient vmClientGetter } func azureVerifyFuncFromOIDCVerifier(cfg *oidc.Config) azureVerifyTokenFunc { @@ -155,12 +154,6 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { } cfg.certificateAuthorities = certs } - if cfg.getVMClient == nil { - cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { - client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil) - return client, trace.Wrap(err) - } - } return nil } diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 3ef911f7707aa..b10671517069f 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -38,7 +38,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/fixtures" ) @@ -54,43 +53,6 @@ func withVerifyFunc(verify azureVerifyTokenFunc) azureRegisterOption { } } -func withVMClientGetter(getVMClient vmClientGetter) azureRegisterOption { - return func(cfg *azureRegisterConfig) { - cfg.getVMClient = getVMClient - } -} - -type mockAzureVMClient struct { - azure.VirtualMachinesClient - vms map[string]*azure.VirtualMachine -} - -func (m *mockAzureVMClient) Get(_ context.Context, resourceID string) (*azure.VirtualMachine, error) { - vm, ok := m.vms[resourceID] - if !ok { - return nil, trace.NotFound("no vm with resource id %q", resourceID) - } - return vm, nil -} - -func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.VirtualMachine, error) { - for _, vm := range m.vms { - if vm.VMID == vmID { - return vm, nil - } - } - return nil, trace.NotFound("no vm with id %q", vmID) -} - -func makeVMClientGetter(clients map[string]*mockAzureVMClient) vmClientGetter { - return func(subscriptionID string, _ *azure.StaticCredential) (azure.VirtualMachinesClient, error) { - if client, ok := clients[subscriptionID]; ok { - return client, nil - } - return nil, trace.NotFound("no client for subscription %q", subscriptionID) - } -} - type azureChallengeResponseConfig struct { Challenge string } @@ -201,13 +163,12 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { defaultIdentityName := "test-id" defaultVMID := "my-vm-id" defaultVMResourceID := vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName) + defaultIdentityResourceID := identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName) tests := []struct { name string tokenManagedIdentityResourceID string tokenAzureResourceID string - tokenSubscription string - tokenVMID string requestTokenName string tokenSpec types.ProvisionTokenSpecV2 challengeResponseOptions []azureChallengeResponseOption @@ -217,31 +178,9 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError require.ErrorAssertionFunc }{ { - name: "basic passing case", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenSpec: types.ProvisionTokenSpecV2{ - Roles: []types.SystemRole{types.RoleNode}, - Azure: &types.ProvisionTokenSpecV2Azure{ - Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ - { - Subscription: defaultSubscription, - ResourceGroups: []string{defaultResourceGroup}, - }, - }, - }, - JoinMethod: types.JoinMethodAzure, - }, - verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: require.NoError, - }, - { - name: "resource group is case insensitive", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, + name: "resource group is case insensitive", + requestTokenName: "test-token", + tokenManagedIdentityResourceID: defaultVMResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -259,10 +198,9 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: require.NoError, }, { - name: "wrong token", - requestTokenName: "wrong-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, + name: "wrong token", + requestTokenName: "wrong-token", + tokenManagedIdentityResourceID: defaultVMResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -279,10 +217,9 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "challenge response error", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, + name: "challenge response error", + requestTokenName: "test-token", + tokenManagedIdentityResourceID: defaultVMResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -300,51 +237,9 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isBadParameter, }, { - name: "wrong subscription", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenSpec: types.ProvisionTokenSpecV2{ - Roles: []types.SystemRole{types.RoleNode}, - Azure: &types.ProvisionTokenSpecV2Azure{ - Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ - { - Subscription: "alternate-subscription-id", - }, - }, - }, - JoinMethod: types.JoinMethodAzure, - }, - verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, - }, - { - name: "wrong resource group", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenSpec: types.ProvisionTokenSpecV2{ - Roles: []types.SystemRole{types.RoleNode}, - Azure: &types.ProvisionTokenSpecV2Azure{ - Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ - { - Subscription: defaultSubscription, - ResourceGroups: []string{"alternate-resource-group"}, - }, - }, - }, - JoinMethod: types.JoinMethodAzure, - }, - verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, - }, - { - name: "wrong challenge", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, + name: "wrong challenge", + requestTokenName: "test-token", + tokenManagedIdentityResourceID: defaultVMResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -364,10 +259,9 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "invalid signature", - requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, + name: "invalid signature", + requestTokenName: "test-token", + tokenManagedIdentityResourceID: defaultVMResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -386,8 +280,6 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "system-managed identity ok", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, tokenManagedIdentityResourceID: vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, @@ -408,8 +300,6 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "system-managed identity with wrong subscription", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, tokenManagedIdentityResourceID: vmResourceID("alternate-subscription-id", defaultResourceGroup, defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, @@ -430,8 +320,6 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "system-managed identity with wrong resource group", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, tokenManagedIdentityResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, @@ -452,10 +340,8 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "user-managed identity ok", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), - tokenAzureResourceID: vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName), + tokenManagedIdentityResourceID: defaultIdentityResourceID, + tokenAzureResourceID: defaultVMResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -475,9 +361,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "user-managed identity with wrong subscription", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), + tokenManagedIdentityResourceID: defaultIdentityResourceID, tokenAzureResourceID: vmResourceID("alternate-subscription-id", defaultResourceGroup, defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, @@ -498,9 +382,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "user-managed identity with wrong resource group", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), + tokenManagedIdentityResourceID: defaultIdentityResourceID, tokenAzureResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultVMName), tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, @@ -521,10 +403,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { { name: "invalid resource type", requestTokenName: "test-token", - tokenSubscription: defaultSubscription, - tokenVMID: defaultVMID, - tokenManagedIdentityResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), - tokenAzureResourceID: identityResourceID(defaultSubscription, defaultResourceGroup, defaultIdentityName), + tokenManagedIdentityResourceID: defaultIdentityResourceID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ @@ -541,6 +420,25 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { certs: []*x509.Certificate{tlsConfig.Certificate}, assertError: isBadParameter, }, + { + name: "resource ID omitted", + requestTokenName: "test-token", + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: defaultSubscription, + ResourceGroups: []string{defaultResourceGroup}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.Error, + }, } for _, tc := range tests { @@ -555,29 +453,9 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { require.NoError(t, a.DeleteToken(ctx, token.GetName())) }) - miRID := tc.tokenManagedIdentityResourceID - if miRID == "" { - miRID = vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName) - } - - accessToken, err := makeToken(miRID, tc.tokenAzureResourceID, a.clock.Now()) + accessToken, err := makeToken(tc.tokenManagedIdentityResourceID, tc.tokenAzureResourceID, a.clock.Now()) require.NoError(t, err) - vmClient := &mockAzureVMClient{ - vms: map[string]*azure.VirtualMachine{ - defaultVMResourceID: { - ID: defaultVMResourceID, - Name: defaultVMName, - Subscription: defaultSubscription, - ResourceGroup: defaultResourceGroup, - VMID: defaultVMID, - }, - }, - } - getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ - defaultSubscription: vmClient, - }) - _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { cfg := &azureChallengeResponseConfig{Challenge: challenge} for _, opt := range tc.challengeResponseOptions { @@ -586,8 +464,8 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { ad := attestedData{ Nonce: cfg.Challenge, - SubscriptionID: tc.tokenSubscription, - ID: tc.tokenVMID, + SubscriptionID: defaultSubscription, + ID: defaultVMID, } adBytes, err := json.Marshal(&ad) require.NoError(t, err) @@ -615,7 +493,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { AccessToken: accessToken, } return req, tc.challengeResponseErr - }, withCerts(tc.certs), withVerifyFunc(tc.verify), withVMClientGetter(getVMClient)) + }, withCerts(tc.certs), withVerifyFunc(tc.verify)) tc.assertError(t, err) }) } diff --git a/lib/cloud/azure/vm.go b/lib/cloud/azure/vm.go index 4503b6f2cb195..463ecd1c34d9c 100644 --- a/lib/cloud/azure/vm.go +++ b/lib/cloud/azure/vm.go @@ -45,8 +45,6 @@ type armCompute interface { type VirtualMachinesClient interface { // Get returns the virtual machine for the given resource ID. Get(ctx context.Context, resourceID string) (*VirtualMachine, error) - // GetByVMID returns the virtual machine for a given VM ID. - GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) // ListVirtualMachines gets all of the virtual machines in the given resource group. ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) } @@ -145,25 +143,6 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, return vm, trace.Wrap(err) } -// GetByVMID returns the virtual machine for a given VM ID. -func (c *vmClient) GetByVMID(ctx context.Context, vmID string) (*VirtualMachine, error) { - pager := newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{})) - for pager.more() { - res, err := pager.nextPage(ctx) - if err != nil { - return nil, trace.Wrap(ConvertResponseError(err)) - } - - for _, vm := range res { - if vm.Properties != nil && *vm.Properties.VMID == vmID { - result, err := parseVirtualMachine(vm) - return result, trace.Wrap(err) - } - } - } - return nil, trace.NotFound("no VM with ID %q", vmID) -} - type vmPager struct { more func() bool nextPage func(context.Context) ([]*armcompute.VirtualMachine, error) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index e2a187357d084..2fc7bc634d77b 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -2609,10 +2609,6 @@ func (m *mockAzureClient) Get(_ context.Context, _ string) (*azure.VirtualMachin return nil, nil } -func (m *mockAzureClient) GetByVMID(_ context.Context, _ string) (*azure.VirtualMachine, error) { - return nil, nil -} - func (m *mockAzureClient) ListVirtualMachines(_ context.Context, _ string) ([]*armcompute.VirtualMachine, error) { return m.vms, nil }